Global de-std::unique_ptr cleanup for xla::Literal.
PiperOrigin-RevId: 212313258
This commit is contained in:
parent 656b3e9c84
commit dd6d7c5c58
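The shape of the cleanup, as a minimal sketch with assumed names (`client`, `add`, and the include paths of this era) rather than code from the commit itself: xla::Literal is now created, stored, and passed by value (it stays move-only), and client entry points return StatusOr<Literal> instead of StatusOr<std::unique_ptr<Literal>>.

```cpp
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"

// Hypothetical post-change call site, mirroring the updated tests below:
// no std::unique_ptr<Literal>, no '*literal' dereferences.
xla::Literal RunAdd(xla::Client* client, const xla::XlaComputation& add) {
  xla::Literal arg0 = xla::LiteralUtil::CreateR1<int32>({7, 42});
  xla::Literal arg1 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> d0 =
      client->TransferToServer(arg0).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> d1 =
      client->TransferToServer(arg1).ConsumeValueOrDie();
  // ExecuteAndTransfer now yields StatusOr<xla::Literal> rather than
  // StatusOr<std::unique_ptr<xla::Literal>>.
  return client->ExecuteAndTransfer(add, {d0.get(), d1.get()})
      .ConsumeValueOrDie();
}
```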
Changed files:
tensorflow/compiler
tf2xla
xla
client
literal.cc, literal.h, literal_test.cc, literal_util.cc, literal_util.h, packed_literal_reader.cc, packed_literal_reader.h, python
local_computation_builder.cc, local_computation_builder.h, local_computation_builder.i, numpy_bridge.cc, numpy_bridge.h
reference_util.cc, reference_util_test.cc, rpc
service
algebraic_simplifier.cc, algebraic_simplifier_test.cc, batchnorm_expander.cc, bfloat16_propagation_test.cc, buffer_assignment_test.cc, buffer_liveness_test.cc, convolution_feature_group_converter.cc
cpu/tests
elemental_ir_emitter_test.cc, generic_transfer_manager.cc, gpu
hlo_constant_folding.cc, hlo_constant_folding_test.cc, hlo_creation_utils.cc, hlo_creation_utils_test.cc, hlo_cse_test.cc, hlo_evaluator.cc, hlo_evaluator.h, hlo_evaluator_test.cc, hlo_evaluator_typed_visitor.h, hlo_instruction.cc, hlo_instruction.h, hlo_instructions.cc, hlo_instructions.h, hlo_parser.cc, hlo_runner.cc, hlo_runner.h, hlo_verifier_test.cc, indexed_array_analysis.cc, indexed_array_analysis.h, inliner_test.cc, interpreter
layout_assignment_test.cc, service.cc, transfer_manager.cc, transfer_manager.h, tuple_points_to_analysis_test.cc, while_loop_analysis.cc, tests
array_elementwise_ops_test.cc, batch_normalization_test.cc, bfloat16_test.cc, broadcast_simple_test.cc, broadcast_test.cc, call_test.cc, check_execution_arity_test.cc, client_library_test_base.cc, client_library_test_base.h, client_test.cc, compilation_cache_test.cc, compute_constant_test.cc, concat_test.cc, conditional_test.cc, constants_test.cc, convert_test.cc, convolution_dimension_numbers_test.cc, convolution_test.cc, convolution_variants_test.cc, copy_test.cc, cross_replica_sum_test.cc, custom_call_test.cc, deconstruct_tuple_test.cc, dot_operation_test.cc, dynamic_ops_test.cc, execution_profile_test.cc, exhaustive_f32_elementwise_op_test.cc
@@ -81,7 +81,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
  TF_ASSIGN_OR_RETURN(auto literal,
  client->ComputeConstant(constant_graph));
  TF_RETURN_IF_ERROR(
- LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
+ LiteralToHostTensor(literal, arg.type, &arg.constant_value));
  } else {
  arg.kind = XlaCompiler::Argument::kParameter;
  }

@@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
  std::vector<xla::XlaOp> args;
  args.push_back(ctx->Input(0));
  args.push_back(xla::ConstantLiteral(
- &b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
+ &b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
  if (input_shape.dims() > 1) {
  // Don't bother passing the output shape and dim for the 1d case, since
  // the shape is always a scalar and the dim is always 0.
  args.push_back(xla::ConstantLiteral(
- &b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
+ &b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
  args.push_back(
- xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
+ xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
  }

  xla::Shape xla_shape =
@@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
  xla::Literal literal;
  switch (type) {
  case xla::U8:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value));
+ literal = xla::LiteralUtil::CreateR0<uint8>(value);
  break;
  case xla::U32:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value));
+ literal = xla::LiteralUtil::CreateR0<uint32>(value);
  break;
  case xla::U64:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value));
+ literal = xla::LiteralUtil::CreateR0<uint64>(value);
  break;
  case xla::S8:
- literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value));
+ literal = xla::LiteralUtil::CreateR0<int8>(value);
  break;
  case xla::S32:
- literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value));
+ literal = xla::LiteralUtil::CreateR0<int32>(value);
  break;
  case xla::S64:
- literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value));
+ literal = xla::LiteralUtil::CreateR0<int64>(value);
  break;
  case xla::F32:
- literal = std::move(*xla::LiteralUtil::CreateR0<float>(value));
+ literal = xla::LiteralUtil::CreateR0<float>(value);
  break;
  case xla::F64:
- literal = std::move(*xla::LiteralUtil::CreateR0<double>(value));
+ literal = xla::LiteralUtil::CreateR0<double>(value);
  break;
  case xla::C64:
- literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value));
+ literal = xla::LiteralUtil::CreateR0<complex64>(value);
  break;
  case xla::PRED:
  LOG(FATAL) << "pred element type is not integral";

@@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
  case xla::U16:
  LOG(FATAL) << "u16/s16 literals not yet implemented";
  case xla::BF16:
- literal = std::move(
- *xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
+ literal =
+ xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value));
  break;
  case xla::F16:
- literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>(
- static_cast<xla::half>(value)));
+ literal =
+ xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value));
  break;
  case xla::TUPLE:
  LOG(FATAL) << "tuple element type is not integral";
@@ -27,19 +27,17 @@ TEST(LiteralUtil, LiteralToHostTensor) {
  // int64 literal can only be converted to an int64 host tensor.
  {
  std::vector<int64> int64_values = {1, 2, 3};
- std::unique_ptr<xla::Literal> int64_values_literal =
+ xla::Literal int64_values_literal =
  xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
  Tensor host_tensor;
  EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
- LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
+ LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor)
  .error_message());
- EXPECT_EQ(
- "Cannot convert literal of type S64 to tensor of type qint32",
- LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor)
- .error_message());
+ EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32",
+ LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor)
+ .error_message());
  EXPECT_TRUE(
- LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor)
- .ok());
+ LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok());
  test::ExpectTensorEqual<int64>(host_tensor,
  test::AsTensor<int64>(int64_values));
  }

@@ -48,23 +46,22 @@ TEST(LiteralUtil, LiteralToHostTensor) {
  // Repeat tests with int32.
  Tensor host_tensor;
  std::vector<int32> int32_values = {10, 11};
- std::unique_ptr<xla::Literal> int32_values_literal =
+ xla::Literal int32_values_literal =
  xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values));
  EXPECT_TRUE(
- LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
- .ok());
+ LiteralToHostTensor(int32_values_literal, DT_INT32, &host_tensor).ok());
  test::ExpectTensorEqual<int32>(host_tensor,
  test::AsTensor<int32>(int32_values));

  EXPECT_TRUE(
- LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor)
+ LiteralToHostTensor(int32_values_literal, DT_QINT32, &host_tensor)
  .ok());
  std::vector<qint32> qint32_values = {10, 11};
  test::ExpectTensorEqual<qint32>(host_tensor,
  test::AsTensor<qint32>(qint32_values));

  EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64",
- LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor)
+ LiteralToHostTensor(int32_values_literal, DT_INT64, &host_tensor)
  .error_message());
  }
  }
@@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) {
  // Set up arguments.
  auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
  auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
- auto x_global_or = client->TransferToServer(*x_literal);
- auto y_global_or = client->TransferToServer(*y_literal);
+ auto x_global_or = client->TransferToServer(x_literal);
+ auto y_global_or = client->TransferToServer(y_literal);
  TF_EXPECT_OK(x_global_or.status());
  TF_EXPECT_OK(y_global_or.status());
  std::unique_ptr<xla::GlobalData> x_global =

@@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) {
  auto result_or =
  client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
  TF_EXPECT_OK(result_or.status());
- std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
- EXPECT_EQ("(s32[]) (\n42\n)", result->ToString());
+ xla::Literal result = std::move(result_or.ValueOrDie());
+ EXPECT_EQ("(s32[]) (\n42\n)", result.ToString());

  config.mutable_feed(0)->mutable_id()->set_output_index(
  123); /* invalid output_index */
@@ -208,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) {
  std::move(graph), args, &result));

  // Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
  client_
  ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
  .ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({4, 143});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
  }

  // Tests compilation of a graph where the _Retval node is not necessarily last

@@ -264,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
  args, &result));

  // Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
  client_
  ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
  .ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal));
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
  }

  // Tests that the compiler doesn't reorder the parameters.

@@ -408,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
  EXPECT_FALSE(result.outputs[1].is_constant);

  // Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
  std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
  client_->Execute(*result.computation, {param0_data.get()})
  .ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
+ xla::Literal actual_literal =
  client_->Transfer(*actual).ConsumeValueOrDie();

- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({-7, -42});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get()});
- EXPECT_TRUE(
- xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
  }

  {

@@ -443,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
  EXPECT_FALSE(result.outputs[1].is_constant);

  // Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
  std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
  client_->Execute(*result.computation, {param0_data.get()})
  .ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
+ xla::Literal actual_literal =
  client_->Transfer(*actual).ConsumeValueOrDie();

- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR0<int32>(7);
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({-7, -42});
- std::unique_ptr<xla::Literal> expected =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+ xla::Literal expected =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
  }
  }

@@ -672,34 +657,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
  update.tensor_array_gradients_accessed);

  // Tests that the generated computation works.
- std::unique_ptr<xla::Literal> input_base =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> input_grad2 =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> input =
- xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
+ xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
  std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*input).ConsumeValueOrDie();
+ client_->TransferToServer(input).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
  client_->Execute(*result.computation, {param0_data.get()})
  .ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

- std::unique_ptr<xla::Literal> output_read =
- xla::LiteralUtil::CreateR0<int32>(42);
- std::unique_ptr<xla::Literal> output_base =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> output_grad1 =
- xla::LiteralUtil::CreateR1<int32>({0, 1});
- std::unique_ptr<xla::Literal> output_grad2 =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
- {output_base.get(), output_grad1.get(), output_grad2.get()});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
+ xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
+ xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal output_resource =
+ xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
  }

  // Tests compilation and execution of a graph that adds two tensors.

@@ -866,29 +843,24 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {

  void RunAndCheckVariablesComputation(
  xla::Client* client, const XlaCompiler::CompilationResult& result) {
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> param0_data =
- client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client->TransferToServer(param0_literal).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> param1_data =
- client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client->TransferToServer(param1_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
  client
  ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
  .ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie();

- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({5, 144});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({4, 143});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
  }

  // Tests a simple graph that reads and writes a variable.

@@ -952,20 +924,17 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
  std::move(graph), args, &result));

  // Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
  client_->Execute(*result.computation, {param1_data.get()})
  .ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
  }

  TEST_F(XlaCompilerTest, ReturnResourceHandle) {

@@ -1069,29 +1038,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
  xla::ShapeUtil::MakeShape(xla::S32, {4})})));

  // Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
  xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
- std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal param1_literal =
  xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
  std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
  client_
  ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
  .ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

- std::unique_ptr<xla::Literal> expected0 =
+ xla::Literal expected0 =
  xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
  }

  TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {

@@ -1138,29 +1105,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
  xla::ShapeUtil::MakeShape(xla::S32, {4})})));

  // Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
  xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
- std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal param1_literal =
  xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
  std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
  client_
  ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
  .ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
  }

  // Tests a graph which has a function with an invalid op.
@@ -213,16 +213,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
  context_->op_kernel().name(), " input ", index,
  ".\nError: ", constant_graph.status().error_message());
  }
- xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
- compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(),
- &layout);
+ xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
+ constant_graph.ValueOrDie(), &layout);
  if (!computed.ok()) {
  return errors::Internal("Error evaluating ", context_->op_kernel().name(),
  " input ", index,
  " as a compile-time constant.\nError: ",
  computed.status().error_message());
  }
- *constant_literal = std::move(*computed.ValueOrDie());
+ *constant_literal = std::move(computed).ValueOrDie();

  return Status::OK();
  }
@@ -37,8 +37,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {}
  Client::~Client() = default;

- StatusOr<std::unique_ptr<Literal>> Client::Transfer(
- const GlobalData& data, const Shape* shape_with_layout) {
+ StatusOr<Literal> Client::Transfer(const GlobalData& data,
+ const Shape* shape_with_layout) {
  TransferToClientRequest request;
  *request.mutable_data() = data.handle();
  if (shape_with_layout != nullptr) {

@@ -114,7 +114,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
  return Status::OK();
  }

- StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
+ StatusOr<Literal> Client::TransferFromOutfeed(
  const Shape* shape_with_layout, int64 replica_id,
  const DeviceHandle* device_handle) {
  TransferFromOutfeedRequest request;

@@ -162,7 +162,7 @@ Status Client::ResetDevice() {
  return Status::OK();
  }

- StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
+ StatusOr<Literal> Client::ExecuteAndTransfer(
  const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
  const ExecutionOptions* execution_options,
  ExecutionProfile* execution_profile) {

@@ -177,8 +177,8 @@ StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
  return Transfer(*data, shape_with_output_layout);
  }

- StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
- const XlaComputation& computation, const Layout* output_layout) const {
+ StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
+ const Layout* output_layout) const {
  ComputeConstantGraphRequest request;
  *request.mutable_computation() = computation.proto();
  if (output_layout != nullptr) {
@@ -96,8 +96,8 @@ class Client {
  //
  // If shape_with_layout is not nullptr, it points to a shape whose layout will
  // be the layout of the returned literal.
- StatusOr<std::unique_ptr<Literal>> Transfer(
- const GlobalData& data, const Shape* shape_with_layout = nullptr);
+ StatusOr<Literal> Transfer(const GlobalData& data,
+ const Shape* shape_with_layout = nullptr);

  // Transfer the given literal to the server. This allocates memory on the
  // device and copies the literal's contents over. Returns a global data handle

@@ -122,7 +122,7 @@ class Client {
  // device_handle and replica_id together specify a particular device; a device
  // assigned for the given replica_id among the replicas that the given device
  // handle belongs to.
- StatusOr<std::unique_ptr<Literal>> TransferFromOutfeed(
+ StatusOr<Literal> TransferFromOutfeed(
  const Shape* shape_with_layout, int64 replica_id = 0,
  const DeviceHandle* device_handle = nullptr);

@@ -132,7 +132,7 @@ class Client {
  // Executes the computation with the given arguments and transfers the result
  // to the client as a literal. Parameters are defined the same as for
  // Execute() and Transfer().
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ StatusOr<Literal> ExecuteAndTransfer(
  const XlaComputation& computation,
  absl::Span<GlobalData* const> arguments,
  const ExecutionOptions* execution_options = nullptr,

@@ -153,7 +153,7 @@ class Client {
  //
  // If output_layout is non-null, then the output of the computation will be
  // stored using that layout.
- StatusOr<std::unique_ptr<Literal>> ComputeConstant(
+ StatusOr<Literal> ComputeConstant(
  const XlaComputation& computation,
  const Layout* output_layout = nullptr) const;
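As a minimal sketch of how a caller consumes the new Client interface shown above (the helper name and the scalar F32 shape are assumed, not part of the commit): the transferred result is held by value rather than through a std::unique_ptr.

```cpp
// Hypothetical helper: read back a device buffer assumed to hold an F32 scalar.
xla::StatusOr<float> ReadScalarF32(xla::Client* client,
                                   const xla::GlobalData& data) {
  // Client::Transfer now returns StatusOr<Literal>, so the literal can be
  // bound directly by value.
  TF_ASSIGN_OR_RETURN(xla::Literal literal, client->Transfer(data));
  return literal.Get<float>({});
}
```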
@@ -76,7 +76,7 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
  std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
  Client* client) {
  if (DataSizeOfShape(shape) < (1LL << 20)) {
- StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
+ StatusOr<Literal> literal_status = MakeFakeLiteral(shape);
  if (!literal_status.ok()) {
  // If we got an Unimplemented error, fall back to making the fake data via
  // an on-device computation.

@@ -84,7 +84,7 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
  tensorflow::error::UNIMPLEMENTED);
  return MakeFakeDataViaDeviceOrDie(shape, client);
  }
- return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie();
+ return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie();
  }

  // If the data is large, generate it on-device.
@@ -195,9 +195,8 @@ Status LocalExecutable::RecordArguments(
  HloSnapshot* hlo_snapshot) {
  hlo_snapshot->clear_arguments();
  for (const ShapedBuffer* argument : arguments) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- LiteralFromShapedBuffer(*argument));
- *hlo_snapshot->add_arguments() = literal->ToProto();
+ TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument));
+ *hlo_snapshot->add_arguments() = literal.ToProto();
  }
  return Status::OK();
  }

@@ -205,13 +204,12 @@ Status LocalExecutable::RecordArguments(
  Status LocalExecutable::RecordResult(const ShapedBuffer* result,
  HloSnapshot* hlo_snapshot) {
  hlo_snapshot->clear_result();
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- LiteralFromShapedBuffer(*result));
- *hlo_snapshot->mutable_result() = literal->ToProto();
+ TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result));
+ *hlo_snapshot->mutable_result() = literal.ToProto();
  return Status::OK();
  }

- StatusOr<std::unique_ptr<Literal>> LocalExecutable::LiteralFromShapedBuffer(
+ StatusOr<Literal> LocalExecutable::LiteralFromShapedBuffer(
  const ShapedBuffer& shaped_buffer) {
  TF_ASSIGN_OR_RETURN(auto stream,
  backend_->BorrowStream(shaped_buffer.device_ordinal()));

@@ -277,7 +275,7 @@ StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
  return std::move(scoped_buffer);
  }

- StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
+ StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
  const ShapedBuffer& shaped_buffer) {
  TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
  shaped_buffer.device_ordinal()));

@@ -298,13 +296,13 @@ Status LocalClient::TransferToInfeedLocal(const Literal& literal,
  literal);
  }

- StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
- const Shape& shape, int device_ordinal) {
+ StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
+ int device_ordinal) {
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
  backend().stream_executor(device_ordinal));
  auto literal = Literal::CreateFromShape(shape);
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
- executor, shape, literal.get()));
+ executor, shape, &literal));
  return std::move(literal);
  }
@@ -84,8 +84,7 @@ class LocalExecutable {
  Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);

  // Returns a literal containing the contents of the given ShapedBuffer.
- StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
- const ShapedBuffer& shaped_buffer);
+ StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);

  // The ordinal of the device which this executable was compiled for. The
  // executable can run on all equivalent devices (as determined by

@@ -132,8 +131,7 @@ class LocalClient : public Client {

  // Copy the data from the device contained in the given ShapedBuffer and
  // return as a Literal.
- StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
- const ShapedBuffer& shaped_buffer);
+ StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);

  // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
  // as long as the handle is valid.

@@ -151,8 +149,8 @@ class LocalClient : public Client {
  // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
  // not inherit from Client and there is no possibility of confusion with
  // Client::TransferFromOutfeed.
- StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
- const Shape& shape, int device_ordinal);
+ StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
+ int device_ordinal);

  // Returns the device ordinal that corresponds to the given replica number.
  //
@@ -738,7 +738,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
  HloInstructionProto instr;
  *instr.mutable_shape() = ShapeUtil::MakeNil();
- *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto();
+ *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
  });
  }

@@ -2112,12 +2112,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,

  template <typename NativeT>
  XlaOp XlaBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
+ return ConstantLiteral(LiteralUtil::CreateR0<NativeT>(value));
  }

  template <typename NativeT>
  XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) {
- return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR1<NativeT>(values));
  }

  template <typename NativeT>

@@ -2129,44 +2129,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
  }

  inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(*LiteralUtil::CreateR1(values));
+ return ConstantLiteral(LiteralUtil::CreateR1(values));
  }

  template <typename NativeT>
  XlaOp XlaBuilder::ConstantR2(
  std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR2<NativeT>(values));
  }

  template <typename NativeT>
  XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
  const Layout& layout) {
  return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
  }

  template <typename NativeT>
  XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateFromArray<NativeT>(values));
  }

  template <typename NativeT>
  XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
  const Array2D<NativeT>& values, const Layout& layout) {
  return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
  }

  template <typename NativeT>
  XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR2FromArray2D<NativeT>(values));
  }

  template <typename NativeT>
  XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
  const Array3D<NativeT>& values, const Layout& layout) {
  return ConstantLiteral(
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
  }

  template <typename NativeT>

@@ -2189,12 +2189,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {

  template <typename NativeT>
  XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
+ return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
  }

  template <typename NativeT>
  XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
  }

  template <typename NativeT>

@@ -2207,13 +2207,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {

  inline XlaOp ConstantR1(XlaBuilder* builder,
  const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
  }

  template <typename NativeT>
  XlaOp ConstantR2(XlaBuilder* builder,
  std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
  }

  template <typename NativeT>

@@ -2221,14 +2221,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
  const Array<NativeT>& values,
  const Layout& layout) {
  return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
  }

  template <typename NativeT>
  XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
  return ConstantLiteral(builder,
- *LiteralUtil::CreateFromArray<NativeT>(values));
+ LiteralUtil::CreateFromArray<NativeT>(values));
  }

  template <typename NativeT>

@@ -2236,15 +2235,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
  const Array2D<NativeT>& values,
  const Layout& layout) {
  return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
  }

  template <typename NativeT>
  XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
  const Array2D<NativeT>& values) {
  return ConstantLiteral(builder,
- *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+ LiteralUtil::CreateR2FromArray2D<NativeT>(values));
  }

  template <typename NativeT>

@@ -2253,7 +2251,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
  const Layout& layout) {
  return ConstantLiteral(
  builder,
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
  }

  template <typename NativeT>
@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
|
||||
auto literal = absl::make_unique<Literal>(shape);
|
||||
literal->root_piece_->ForEachMutableSubpiece(
|
||||
Literal LiteralBase::CreateFromShape(const Shape& shape) {
|
||||
Literal literal(shape);
|
||||
literal.root_piece_->ForEachMutableSubpiece(
|
||||
[&](const ShapeIndex& index, Piece* piece) {
|
||||
if (ShapeUtil::IsArray(piece->subshape())) {
|
||||
memset(piece->untyped_data(), 0, piece->size_bytes());
|
||||
@ -278,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* static */ StatusOr<std::unique_ptr<Literal>>
|
||||
MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
|
||||
/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
|
||||
const LiteralProto& proto) {
|
||||
if (!proto.has_shape()) {
|
||||
return InvalidArgument("LiteralProto has no shape");
|
||||
}
|
||||
@ -287,9 +287,9 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
|
||||
return InvalidArgument("LiteralProto has no layout");
|
||||
}
|
||||
|
||||
auto literal = absl::make_unique<Literal>(proto.shape());
|
||||
Literal literal(proto.shape());
|
||||
|
||||
TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
|
||||
TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
|
||||
[&](const ShapeIndex& index, Piece* piece) {
|
||||
const LiteralProto* proto_element = &proto;
|
||||
for (int64 i : index) {
|
||||
@ -556,38 +556,37 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Literal> LiteralBase::Relayout(
|
||||
const Layout& new_layout, const ShapeIndex& shape_index) const {
|
||||
Literal LiteralBase::Relayout(const Layout& new_layout,
|
||||
const ShapeIndex& shape_index) const {
|
||||
// Create new shape with 'new_layout' set at the given shape index.
|
||||
Shape new_shape = shape();
|
||||
Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
|
||||
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
|
||||
*subshape->mutable_layout() = new_layout;
|
||||
auto result = absl::make_unique<Literal>(new_shape);
|
||||
TF_CHECK_OK(result->CopyFrom(*this));
|
||||
Literal result(new_shape);
|
||||
TF_CHECK_OK(result.CopyFrom(*this));
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unique_ptr<Literal> LiteralBase::Relayout(
|
||||
const Shape& shape_with_layout) const {
|
||||
Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
|
||||
CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
|
||||
<< "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
|
||||
<< " not compatible with literal shape "
|
||||
<< ShapeUtil::HumanString(shape());
|
||||
std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
|
||||
Literal result = CreateFromShape(shape_with_layout);
|
||||
ShapeUtil::ForEachSubshape(
|
||||
result->shape(),
|
||||
result.shape(),
|
||||
[this, &result](const Shape& subshape, const ShapeIndex& index) {
|
||||
if (ShapeUtil::IsArray(subshape)) {
|
||||
TF_CHECK_OK(result->CopyFrom(*this,
|
||||
/*dest_shape_index=*/index,
|
||||
/*src_shape_index=*/index));
|
||||
TF_CHECK_OK(result.CopyFrom(*this,
|
||||
/*dest_shape_index=*/index,
|
||||
/*src_shape_index=*/index));
|
||||
}
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
|
||||
StatusOr<Literal> LiteralBase::Broadcast(
|
||||
const Shape& result_shape, absl::Span<const int64> dimensions) const {
|
||||
if (!ShapeUtil::IsArray(shape())) {
|
||||
return InvalidArgument("Broadcast only supports arrays.");
|
||||
@ -598,14 +597,14 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
|
||||
result_shape.dimensions(dimensions[i]));
|
||||
}
|
||||
|
||||
std::unique_ptr<Literal> result = absl::make_unique<Literal>(result_shape);
|
||||
Literal result(result_shape);
|
||||
|
||||
// scratch_source_index is temporary storage space for the computed index into
|
||||
// the input literal. We put it here to avoid allocating an std::vector in
|
||||
// every iteration of ShapeUtil::ForEachIndex.
|
||||
std::vector<int64> scratch_source_index(shape().dimensions_size());
|
||||
|
||||
char* dest_data = static_cast<char*>(result->untyped_data());
|
||||
char* dest_data = static_cast<char*>(result.untyped_data());
|
||||
const char* source_data = static_cast<const char*>(untyped_data());
|
||||
const int64 primitive_size =
|
||||
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
|
||||
@ -627,37 +626,36 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
|
||||
StatusOr<Literal> LiteralBase::Reshape(
|
||||
absl::Span<const int64> dimensions) const {
|
||||
if (!ShapeUtil::IsArray(shape())) {
|
||||
return InvalidArgument("Reshape does not support tuples.");
|
||||
}
|
||||
std::unique_ptr<Literal> output;
|
||||
Literal output;
|
||||
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
|
||||
output =
|
||||
Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
|
||||
} else {
|
||||
output = CloneToUnique();
|
||||
output = Clone();
|
||||
}
|
||||
// Because the layout is monotonic, we can simply reuse the same sequence of
|
||||
// values without changing their order.
|
||||
*output->mutable_shape_do_not_use() =
|
||||
*output.mutable_shape_do_not_use() =
|
||||
ShapeUtil::MakeShape(shape().element_type(), dimensions);
|
||||
|
||||
int64 elements_before = ShapeUtil::ElementsIn(shape());
|
||||
int64 elements_after = ShapeUtil::ElementsIn(output->shape());
|
||||
int64 elements_after = ShapeUtil::ElementsIn(output.shape());
|
||||
if (elements_before != elements_after) {
|
||||
return InvalidArgument(
|
||||
"Shapes before and after Literal::Reshape have different numbers "
|
||||
"of elements: %s vs %s.",
|
||||
ShapeUtil::HumanString(shape()),
|
||||
ShapeUtil::HumanString(output->shape()));
|
||||
ShapeUtil::HumanString(output.shape()));
|
||||
}
|
||||
return std::move(output);
|
||||
}
|
||||
|
||||
std::unique_ptr<Literal> LiteralBase::Transpose(
|
||||
absl::Span<const int64> permutation) const {
|
||||
Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
|
||||
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
|
||||
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
|
||||
<< "Given permutation is not a permutation of dimension numbers";
|
||||
@ -687,32 +685,31 @@ std::unique_ptr<Literal> LiteralBase::Transpose(
|
||||
for (auto index : LayoutUtil::MinorToMajor(shape())) {
|
||||
layout->add_minor_to_major(inverse_permutation[index]);
|
||||
}
|
||||
auto new_literal = absl::make_unique<Literal>(permuted_shape);
|
||||
DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
|
||||
Literal new_literal(permuted_shape);
|
||||
DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
|
||||
ShapeUtil::ByteSizeOf(shape()));
|
||||
std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
|
||||
std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
|
||||
return new_literal;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
std::unique_ptr<Literal> LiteralBase::SliceInternal(
|
||||
Literal LiteralBase::SliceInternal(
|
||||
const Shape& result_shape, absl::Span<const int64> start_indices) const {
|
||||
auto result_literal = absl::make_unique<Literal>(result_shape);
|
||||
Literal result_literal(result_shape);
|
||||
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
|
||||
result_literal->EachCell<NativeT>(
|
||||
result_literal.EachCell<NativeT>(
|
||||
[&](absl::Span<const int64> indices, NativeT /*value*/) {
|
||||
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
|
||||
new_indices[i] = indices[i] + start_indices[i];
|
||||
}
|
||||
NativeT value = Get<NativeT>(new_indices);
|
||||
result_literal->Set<NativeT>(indices, value);
|
||||
result_literal.Set<NativeT>(indices, value);
|
||||
});
|
||||
return result_literal;
|
||||
}
|
||||
|
||||
std::unique_ptr<Literal> LiteralBase::Slice(
|
||||
absl::Span<const int64> start_indices,
|
||||
absl::Span<const int64> limit_indices) const {
|
||||
Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
|
||||
absl::Span<const int64> limit_indices) const {
|
||||
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
|
||||
|
||||
DimensionVector result_dimensions;
|
||||
@ -750,12 +747,6 @@ Literal LiteralBase::Clone() const {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
|
||||
auto result = absl::make_unique<Literal>(shape());
|
||||
TF_CHECK_OK(result->CopyFrom(*this));
|
||||
return result;
|
||||
}
|
||||
|
||||
string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
|
||||
const ShapeIndex& shape_index) const {
|
||||
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
|
||||
@ -1191,14 +1182,14 @@ void LiteralBase::EachCellAsString(
|
||||
|
||||
namespace {
|
||||
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
|
||||
std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
|
||||
const LiteralBase& src_literal, const ConverterType& converter) {
|
||||
Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
|
||||
const ConverterType& converter) {
|
||||
CHECK(ShapeUtil::IsArray(src_literal.shape()));
|
||||
auto result_literal = absl::make_unique<Literal>(ShapeUtil::ChangeElementType(
|
||||
Literal result_literal(ShapeUtil::ChangeElementType(
|
||||
src_literal.shape(),
|
||||
primitive_util::NativeToPrimitiveType<NativeDestT>()));
|
||||
auto src_data = src_literal.data<NativeSrcT>();
|
||||
auto dest_data = result_literal->template data<NativeDestT>();
auto dest_data = result_literal.template data<NativeDestT>();
int64 num_elements = src_literal.element_count();

for (int64 i = 0; i < num_elements; ++i) {
@ -1208,8 +1199,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
}

template <typename NativeSrcT, typename NativeDestT>
std::unique_ptr<Literal> ConvertBetweenNativeTypes(
const LiteralBase& src_literal) {
Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
@ -1217,7 +1207,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(

template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
std::unique_ptr<Literal>>::type
Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return tensorflow::bit_cast<NativeDestT>(src);
@ -1232,20 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
// identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
std::unique_ptr<Literal>>::type
Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}

template <PrimitiveType primitive_src_type>
std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
Literal ConvertToC64(const LiteralBase& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
auto result_literal = absl::make_unique<Literal>(
Literal result_literal(
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
using NativeSrcT =
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
absl::Span<complex64> dest_data = result_literal->data<complex64>();
absl::Span<complex64> dest_data = result_literal.data<complex64>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
@ -1254,8 +1244,7 @@ std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
}

template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
bool bitcast) {
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
if (bitcast) {
return BitcastBetweenNativeTypes<
@ -1273,9 +1262,9 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
}

template <PrimitiveType primitive_src_type>
StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
bool bitcast) {
StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
PrimitiveType primitive_dest_type,
bool bitcast) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
case (type): \
@ -1307,12 +1296,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
PrimitiveType_Name(primitive_dest_type));
}

StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
const LiteralBase& literal, PrimitiveType primitive_dest_type,
bool bitcast) {
StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
PrimitiveType primitive_dest_type,
bool bitcast) {
TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
if (literal.shape().element_type() == primitive_dest_type) {
return literal.CloneToUnique();
return literal.Clone();
}
switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
@ -1342,12 +1331,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch(

} // namespace

StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
StatusOr<Literal> LiteralBase::Convert(
PrimitiveType primitive_dest_type) const {
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}

StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
StatusOr<Literal> LiteralBase::BitcastConvert(
PrimitiveType primitive_dest_type) const {
if (primitive_util::BitWidth(shape().element_type()) !=
primitive_util::BitWidth(primitive_dest_type)) {
@ -1362,8 +1351,8 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}

StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
const Shape& dest_shape, bool round_f32_to_bf16) const {
StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape,
bool round_f32_to_bf16) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
if (round_f32_to_bf16 && shape().element_type() == F32 &&
dest_shape.element_type() == BF16) {
@ -1381,11 +1370,9 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
TF_ASSIGN_OR_RETURN(
auto new_element,
element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
elements.push_back(std::move(*new_element));
elements.push_back(std::move(new_element));
}
auto converted = absl::make_unique<Literal>();
*converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
return std::move(converted);
return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}

/* static */ Literal MutableLiteralBase::MoveIntoTuple(
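
To make the ownership change above concrete, here is a minimal caller-side sketch; the helper name WidenToF64 and the chosen element types are invented for illustration, and the include paths are written as they appear in this tree rather than checked against any particular revision:

#include <utility>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Hypothetical helper: widens an F32 literal to F64. Before this cleanup it
// would have returned StatusOr<std::unique_ptr<Literal>> and callers would
// have dereferenced the result.
StatusOr<Literal> WidenToF64(const Literal& input) {
  // Convert() now yields the converted Literal by value inside the StatusOr.
  TF_ASSIGN_OR_RETURN(Literal widened, input.Convert(F64));
  return std::move(widened);  // Literal is movable, so no heap indirection.
}

}  // namespace xla

The same pattern applies to BitcastConvert() and ConvertToShape() as changed above.
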
@ -223,25 +223,21 @@ class LiteralBase {
//
// TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
// the default behavior.
StatusOr<std::unique_ptr<Literal>> ConvertToShape(
const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
StatusOr<Literal> ConvertToShape(const Shape& dest_shape,
bool round_f32_to_bf16 = false) const;

// Converts this literal to another primitive type using a bitcast
// conversion. The to and from primitive types must have the same bit
// width. Returns an error if the conversion is not possible. This literal
// must be array-shaped.
StatusOr<std::unique_ptr<Literal>> BitcastConvert(
PrimitiveType primitive_dest_type) const;
StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;

// Converts this literal to another primitive type. Returns an error if the
// conversion is not possible. This literal must be array-shaped.
StatusOr<std::unique_ptr<Literal>> Convert(
PrimitiveType primitive_dest_type) const;
StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;

// Clones the underlying buffers into a new Literal, or new
// std::unique_ptr<Literal>.
// Clones the underlying buffers into a new Literal.
Literal Clone() const;
std::unique_ptr<Literal> CloneToUnique() const;
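
As a rough illustration of the value semantics implied by the declarations above (Literal handed around by value, with copies made only through the explicit Clone()), a small sketch; the element type and values are arbitrary:

#include <utility>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {

void CloneVersusMove() {
  // Factory functions now hand back a Literal value directly.
  Literal original = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});

  // An independent copy of the underlying buffers: use the explicit Clone().
  Literal copy = original.Clone();

  // Transferring the buffers without copying: use std::move; `original`
  // should not be relied upon afterwards.
  Literal adopted = std::move(original);

  (void)copy;
  (void)adopted;
}

}  // namespace xla
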
// TODO(b/67651157): The methods below which perform computation on Literals
// (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
@ -259,24 +255,23 @@ class LiteralBase {
// Note: this is useful when the client wants to ensure that a value placed in
// the XLA allocation tracker has a particular layout; for efficiency
// purposes or avoiding unimplemented operation/layout combinations.
std::unique_ptr<Literal> Relayout(const Layout& new_layout,
const ShapeIndex& shape_index = {}) const;
Literal Relayout(const Layout& new_layout,
const ShapeIndex& shape_index = {}) const;

// An overload of Relayout which changes the layout of the entire shape rather
// than being limited to a single array within the shape.
std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
Literal Relayout(const Shape& shape_with_layout) const;

// Creates a new literal by reshaping this literal to have the given
// dimensions. The total number of elements must not change; The
// implementation currently only supports monotonic dim0-major layouts.
// This literal must be an array.
StatusOr<std::unique_ptr<Literal>> Reshape(
absl::Span<const int64> dimensions) const;
StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;

// Creates a new literal by broadcasting this literal with `dimensions` to
// yield a literal of shape `result_shape`.
StatusOr<std::unique_ptr<Literal>> Broadcast(
const Shape& result_shape, absl::Span<const int64> dimensions) const;
StatusOr<Literal> Broadcast(const Shape& result_shape,
absl::Span<const int64> dimensions) const;

// Creates a new literal by reordering the dimensions of this literal.
// The given `permutation` must be a permutation of the dimension numbers
@ -285,7 +280,7 @@ class LiteralBase {
// For example, a transpose call on a literal of shape [3 x 8 x 4] and
// `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
// This literal must be an array.
std::unique_ptr<Literal> Transpose(absl::Span<const int64> permutation) const;
Literal Transpose(absl::Span<const int64> permutation) const;

// Creates a sub-array from this literal by extracting the indices
// [start_index, limit_index) of each dimension. The result literal has the
@ -293,15 +288,15 @@ class LiteralBase {
// start_indices and limit_indices must be the rank of the literal, and the
// indices follow the order of the dimensions.
// This literal must be an array.
std::unique_ptr<Literal> Slice(absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices) const;
Literal Slice(absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices) const;

// Creates a literal with a prepended dimension with bound "times"; e.g. a
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
// literal replicated four times.
// This literal must be an array.
template <typename NativeT>
std::unique_ptr<Literal> Replicate(int64 times) const;
Literal Replicate(int64 times) const;
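
For illustration, a minimal sketch of how the shape-manipulating helpers above compose once they return Literal (or StatusOr<Literal>) by value; the concrete dimensions and values are made up for the example:

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

StatusOr<Literal> ReshapeAndSlice() {
  // A 2x3 F32 literal with arbitrary values.
  Literal input =
      LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}});

  // Reshape() still reports failure through StatusOr, but the payload is a
  // Literal rather than a std::unique_ptr<Literal>.
  TF_ASSIGN_OR_RETURN(Literal reshaped, input.Reshape({3, 2}));

  // Slice() and Transpose() return plain Literal values.
  Literal sliced = reshaped.Slice(/*start_indices=*/{0, 0},
                                  /*limit_indices=*/{2, 2});
  return sliced.Transpose(/*permutation=*/{1, 0});
}

}  // namespace xla
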
// Creates a new Literal object with the shape specified as parameter.
// The content of the literal values is the default value of the primitive
@ -312,7 +307,7 @@ class LiteralBase {
// initialization, then reinitialization. Consider if a call to
// absl::make_unique<Literal>(shape), followed by the call to
// MutableLiteralBase::Populate can be used instead.
static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
static Literal CreateFromShape(const Shape& shape);

protected:
// A data structure representing a subshape at a particular ShapeIndex within
@ -539,8 +534,8 @@ class LiteralBase {

private:
template <typename NativeT>
std::unique_ptr<Literal> SliceInternal(
const Shape& result_shape, absl::Span<const int64> start_indices) const;
Literal SliceInternal(const Shape& result_shape,
absl::Span<const int64> start_indices) const;
};

// Abstract base class representing a mutable literal in XLA.
@ -687,8 +682,7 @@ class MutableLiteralBase : public LiteralBase {
static Literal MoveIntoTuple(absl::Span<Literal> elements);

// Serialize from a proto.
static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
const LiteralProto& proto);
static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);

protected:
// Returns the piece at the given ShapeIndex.
@ -1137,15 +1131,14 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) {
}

template <typename NativeT>
std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
Literal LiteralBase::Replicate(int64 times) const {
DimensionVector bounds = {times};
bounds.reserve(shape().dimensions_size() + 1);
for (int64 bound : shape().dimensions()) {
bounds.push_back(bound);
}
auto literal = absl::make_unique<Literal>(
ShapeUtil::MakeShape(shape().element_type(), bounds));
int64 elements = ShapeUtil::ElementsIn(literal->shape());
Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
int64 elements = ShapeUtil::ElementsIn(literal.shape());
if (elements == 0) {
return literal;
}
@ -1157,7 +1150,7 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
bool done = false;
while (!done) {
const auto element = Get<NativeT>(input_indices);
literal->Set<NativeT>(output_indices, element);
literal.Set<NativeT>(output_indices, element);

done = true;
for (int n = 0; n < output_indices.size(); ++n) {
File diff suppressed because it is too large

@ -45,7 +45,7 @@ using absl::StrCat;
// Return a literal with all arrays of type FromNativeT converted to type
// ToNativeT in the given literal.
template <typename FromNativeT, typename ToNativeT>
std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
Literal ConvertType(LiteralSlice literal) {
// First construct shape of the result.
Shape result_shape(literal.shape());
ShapeUtil::ForEachMutableSubshape(
@ -56,7 +56,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
primitive_util::NativeToPrimitiveType<ToNativeT>());
}
});
auto result = absl::make_unique<Literal>(result_shape);
Literal result(result_shape);

// Then copy over the data from 'literal' converting FromNativeT values to
// ToNativeT values as necessary.
@ -67,14 +67,14 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
if (subshape.element_type() ==
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
auto src = literal.data<FromNativeT>(shape_index);
auto dest = result->data<ToNativeT>(shape_index);
auto dest = result.data<ToNativeT>(shape_index);
for (int64 i = 0; i < src.size(); ++i) {
dest[i] = static_cast<ToNativeT>(src[i]);
}
} else {
TF_CHECK_OK(result->CopyFrom(literal,
/*dest_shape_index=*/shape_index,
/*src_shape_index=*/shape_index));
TF_CHECK_OK(result.CopyFrom(literal,
/*dest_shape_index=*/shape_index,
/*src_shape_index=*/shape_index));
}
}
});
@ -83,53 +83,52 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {

} // namespace

/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
/* static */ Literal LiteralUtil::CreateFromDimensions(
PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
return Literal::CreateFromShape(
ShapeUtil::MakeShape(primitive_type, dimensions));
}

/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertBF16ToF32(
/* static */ Literal LiteralUtil::ConvertBF16ToF32(
const LiteralSlice& bf16_literal) {
return ConvertType<bfloat16, float>(bf16_literal);
}

/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertF32ToBF16(
/* static */ Literal LiteralUtil::ConvertF32ToBF16(
const LiteralSlice& f32_literal) {
return ConvertType<float, bfloat16>(f32_literal);
}

/* static */ std::unique_ptr<Literal> LiteralUtil::CreateToken() {
return absl::make_unique<Literal>(ShapeUtil::MakeTokenShape());
/* static */ Literal LiteralUtil::CreateToken() {
return Literal(ShapeUtil::MakeTokenShape());
}

/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
return std::move(*LiteralUtil::CreateR0<uint8>(0));
return LiteralUtil::CreateR0<uint8>(0);
case U32:
return std::move(*LiteralUtil::CreateR0<uint32>(0));
return LiteralUtil::CreateR0<uint32>(0);
case U64:
return std::move(*LiteralUtil::CreateR0<uint64>(0));
return LiteralUtil::CreateR0<uint64>(0);
case S8:
return std::move(*LiteralUtil::CreateR0<int8>(0));
return LiteralUtil::CreateR0<int8>(0);
case S32:
return std::move(*LiteralUtil::CreateR0<int32>(0));
return LiteralUtil::CreateR0<int32>(0);
case S64:
return std::move(*LiteralUtil::CreateR0<int64>(0));
return LiteralUtil::CreateR0<int64>(0);
case F16:
return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(0.0f)));
return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
case BF16:
return std::move(
*LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
case F32:
return std::move(*LiteralUtil::CreateR0<float>(0));
return LiteralUtil::CreateR0<float>(0);
case F64:
return std::move(*LiteralUtil::CreateR0<double>(0));
return LiteralUtil::CreateR0<double>(0);
case C64:
return std::move(*LiteralUtil::CreateR0<complex64>(0));
return LiteralUtil::CreateR0<complex64>(0);
case PRED:
return std::move(*LiteralUtil::CreateR0<bool>(false));
return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
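
For reference, a short sketch of what call sites of these scalar factories look like after the change; the chosen types and values are arbitrary, and the include path is shown as it appears in this tree:

#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {

void ScalarFactories() {
  // Each factory returns a Literal value; there is no std::move(*...) dance
  // and no unique_ptr indirection at the call site anymore.
  Literal zero_f32 = LiteralUtil::Zero(F32);
  Literal one_s32 = LiteralUtil::One(S32);
  Literal pi = LiteralUtil::CreateR0<float>(3.14159f);

  (void)zero_f32;
  (void)one_s32;
  (void)pi;
}

}  // namespace xla
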
@ -145,30 +144,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
return std::move(*LiteralUtil::CreateR0<uint8>(1));
return LiteralUtil::CreateR0<uint8>(1);
case U32:
return std::move(*LiteralUtil::CreateR0<uint32>(1));
return LiteralUtil::CreateR0<uint32>(1);
case U64:
return std::move(*LiteralUtil::CreateR0<uint64>(1));
return LiteralUtil::CreateR0<uint64>(1);
case S8:
return std::move(*LiteralUtil::CreateR0<int8>(1));
return LiteralUtil::CreateR0<int8>(1);
case S32:
return std::move(*LiteralUtil::CreateR0<int32>(1));
return LiteralUtil::CreateR0<int32>(1);
case S64:
return std::move(*LiteralUtil::CreateR0<int64>(1));
return LiteralUtil::CreateR0<int64>(1);
case F16:
return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(1.0f)));
return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
case BF16:
return std::move(
*LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
case F32:
return std::move(*LiteralUtil::CreateR0<float>(1));
return LiteralUtil::CreateR0<float>(1);
case F64:
return std::move(*LiteralUtil::CreateR0<double>(1));
return LiteralUtil::CreateR0<double>(1);
case C64:
return std::move(*LiteralUtil::CreateR0<complex64>(1));
return LiteralUtil::CreateR0<complex64>(1);
case PRED:
return std::move(*LiteralUtil::CreateR0<bool>(true));
return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@ -184,42 +182,36 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
return std::move(
*LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
case U32:
return std::move(
*LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
case U64:
return std::move(
*LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
case S8:
return std::move(
*LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min()));
return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
case S32:
return std::move(
*LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min()));
return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
case S64:
return std::move(
*LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min()));
return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
case F32:
return std::move(*LiteralUtil::CreateR0<float>(
-std::numeric_limits<float>::infinity()));
return LiteralUtil::CreateR0<float>(
-std::numeric_limits<float>::infinity());
case F64:
return std::move(*LiteralUtil::CreateR0<double>(
-std::numeric_limits<double>::infinity()));
return LiteralUtil::CreateR0<double>(
-std::numeric_limits<double>::infinity());
case C64:
LOG(FATAL) << "C64 element type has no minimum value";
case PRED:
return std::move(*LiteralUtil::CreateR0<bool>(false));
return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
return std::move(*LiteralUtil::CreateR0<half>(
static_cast<half>(-std::numeric_limits<float>::infinity())));
return LiteralUtil::CreateR0<half>(
static_cast<half>(-std::numeric_limits<float>::infinity()));
case BF16:
return std::move(*LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
return LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value";
case OPAQUE:
@ -232,40 +224,34 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
return std::move(
*LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
case U32:
return std::move(
*LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
case U64:
return std::move(
*LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
case S8:
return std::move(
*LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max()));
return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
case S32:
return std::move(
*LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max()));
return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
case S64:
return std::move(
*LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max()));
return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
case F32:
return std::move(*LiteralUtil::CreateR0<float>(
std::numeric_limits<float>::infinity()));
return LiteralUtil::CreateR0<float>(
std::numeric_limits<float>::infinity());
case F64:
return std::move(*LiteralUtil::CreateR0<double>(
std::numeric_limits<double>::infinity()));
return LiteralUtil::CreateR0<double>(
std::numeric_limits<double>::infinity());
case PRED:
return std::move(*LiteralUtil::CreateR0<bool>(true));
return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
return std::move(*LiteralUtil::CreateR0<half>(
static_cast<half>(std::numeric_limits<float>::infinity())));
return LiteralUtil::CreateR0<half>(
static_cast<half>(std::numeric_limits<float>::infinity()));
case BF16:
return std::move(*LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
return LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value";
case OPAQUE:
@ -275,31 +261,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}
}

/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
/* static */ Literal LiteralUtil::CreateR1(
const tensorflow::core::Bitmap& values) {
auto literal = absl::make_unique<Literal>(
Literal literal(
ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
literal->PopulateR1(values);
literal.PopulateR1(values);
return literal;
}

/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
absl::string_view value) {
auto literal = absl::make_unique<Literal>(
ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
for (int i = 0; i < value.size(); ++i) {
literal->Set<uint8>({i}, value[i]);
literal.Set<uint8>({i}, value[i]);
}
return literal;
}

/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2F32Linspace(
float from, float to, int64 rows, int64 cols) {
/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
int64 rows, int64 cols) {
auto value = MakeLinspaceArray2D(from, to, rows, cols);
return CreateR2FromArray2D(*value);
}

/* static */ std::unique_ptr<Literal> LiteralUtil::ReshapeSlice(
/* static */ Literal LiteralUtil::ReshapeSlice(
absl::Span<const int64> new_dimensions,
absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
int64 new_num_elements = 1;
@ -309,13 +293,13 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
CHECK_EQ(new_dimensions.size(), minor_to_major.size());

auto new_literal = absl::make_unique<Literal>(
Literal new_literal(
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));

// Create a new shape with the given minor-to-major layout. This shape is used
// solely for converting linear address to multi-dimensional addresses when
// writing elements to the new literal.
Shape shape_with_layout = new_literal->shape();
Shape shape_with_layout = new_literal.shape();
*shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);

// Copy data into new literal, element-by-element.
@ -326,40 +310,40 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
switch (literal.shape().element_type()) {
case PRED:
new_literal->Set<bool>(to_multi_index,
literal.Get<bool>(from_multi_index));
new_literal.Set<bool>(to_multi_index,
literal.Get<bool>(from_multi_index));
break;
case U8:
new_literal->Set<uint8>(to_multi_index,
literal.Get<uint8>(from_multi_index));
new_literal.Set<uint8>(to_multi_index,
literal.Get<uint8>(from_multi_index));
break;
case U32:
new_literal->Set<uint32>(to_multi_index,
literal.Get<uint32>(from_multi_index));
new_literal.Set<uint32>(to_multi_index,
literal.Get<uint32>(from_multi_index));
break;
case S32:
new_literal->Set<int32>(to_multi_index,
literal.Get<int32>(from_multi_index));
new_literal.Set<int32>(to_multi_index,
literal.Get<int32>(from_multi_index));
break;
case U64:
new_literal->Set<uint64>(to_multi_index,
literal.Get<uint64>(from_multi_index));
new_literal.Set<uint64>(to_multi_index,
literal.Get<uint64>(from_multi_index));
break;
case S64:
new_literal->Set<int64>(to_multi_index,
literal.Get<int64>(from_multi_index));
new_literal.Set<int64>(to_multi_index,
literal.Get<int64>(from_multi_index));
break;
case F32:
new_literal->Set<float>(to_multi_index,
literal.Get<float>(from_multi_index));
new_literal.Set<float>(to_multi_index,
literal.Get<float>(from_multi_index));
break;
case F64:
new_literal->Set<double>(to_multi_index,
literal.Get<double>(from_multi_index));
new_literal.Set<double>(to_multi_index,
literal.Get<double>(from_multi_index));
break;
case C64:
new_literal->Set<complex64>(to_multi_index,
literal.Get<complex64>(from_multi_index));
new_literal.Set<complex64>(to_multi_index,
literal.Get<complex64>(from_multi_index));
break;
default:
LOG(FATAL) << "Unhandled primitive element type: "
@ -376,97 +360,82 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
switch (literal.shape().element_type()) {
case PRED:
return std::move(
*LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>()));
return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
// 8 bit types.
case S8:
return std::move(
*LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>()));
return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
case U8:
return std::move(
*LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>()));
return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
// 16 bit types.
case BF16:
return std::move(*LiteralUtil::CreateR0<bfloat16>(
literal.GetFirstElement<bfloat16>()));
return LiteralUtil::CreateR0<bfloat16>(
literal.GetFirstElement<bfloat16>());
case F16:
return std::move(
*LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>()));
return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
case S16:
return std::move(
*LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>()));
return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
case U16:
return std::move(
*LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>()));
return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
// 32 bit types.
case F32:
return std::move(
*LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>()));
return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
case S32:
return std::move(
*LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>()));
return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
case U32:
return std::move(
*LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>()));
return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
// 64 bit types.
case C64:
return std::move(*LiteralUtil::CreateR0<complex64>(
literal.GetFirstElement<complex64>()));
return LiteralUtil::CreateR0<complex64>(
literal.GetFirstElement<complex64>());
case F64:
return std::move(
*LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>()));
return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
case S64:
return std::move(
*LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>()));
return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
case U64:
return std::move(
*LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>()));
return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
default:
LOG(FATAL) << "Unhandled primitive type "
<< literal.shape().element_type();
}
}

/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
/* static */ Literal LiteralUtil::MakeTuple(
absl::Span<const Literal* const> elements) {
std::vector<Shape> element_shapes;
for (const auto* element : elements) {
element_shapes.push_back(element->shape());
}
auto literal =
absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
}
return literal;
}

/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleFromSlices(
/* static */ Literal LiteralUtil::MakeTupleFromSlices(
absl::Span<const LiteralSlice> elements) {
std::vector<Shape> element_shapes;
for (const auto& element : elements) {
element_shapes.push_back(element.shape());
}
auto literal =
absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
}
return literal;
}

/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleOwned(
std::vector<std::unique_ptr<Literal>> elements) {
/* static */ Literal LiteralUtil::MakeTupleOwned(
std::vector<Literal> elements) {
std::vector<Shape> element_shapes;
element_shapes.reserve(elements.size());
for (const auto& element : elements) {
element_shapes.push_back(element->shape());
element_shapes.push_back(element.shape());
}
auto literal =
absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int64 i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(
literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i}));
literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
}
return literal;
}
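
To show how the tuple factories above are used once they traffic in Literal values, a minimal sketch; the element types and values are arbitrary:

#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {

void BuildTuples() {
  // MakeTupleOwned() now consumes a vector of Literal values and moves each
  // element into the tuple, instead of taking vector<unique_ptr<Literal>>.
  std::vector<Literal> elements;
  elements.push_back(LiteralUtil::CreateR0<int32>(7));
  elements.push_back(LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
  Literal owned_tuple = LiteralUtil::MakeTupleOwned(std::move(elements));

  // MakeTuple() still copies out of the pointed-to elements.
  Literal a = LiteralUtil::CreateR0<float>(1.0f);
  Literal b = LiteralUtil::CreateR0<float>(2.0f);
  Literal copied_tuple = LiteralUtil::MakeTuple({&a, &b});

  (void)owned_tuple;
  (void)copied_tuple;
}

}  // namespace xla
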
@ -69,36 +69,34 @@ class LiteralUtil {
|
||||
// The variants not ending with WithLayout use the default XLA layout for the
|
||||
// literal's linear representation in memory.
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR0(NativeT value);
|
||||
static Literal CreateR0(NativeT value);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR1(absl::Span<const NativeT> values);
|
||||
static std::unique_ptr<Literal> CreateR1(
|
||||
const tensorflow::core::Bitmap& values);
|
||||
static Literal CreateR1(absl::Span<const NativeT> values);
|
||||
static Literal CreateR1(const tensorflow::core::Bitmap& values);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR2(
|
||||
static Literal CreateR2(
|
||||
std::initializer_list<std::initializer_list<NativeT>> values);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR2WithLayout(
|
||||
static Literal CreateR2WithLayout(
|
||||
std::initializer_list<std::initializer_list<NativeT>> values,
|
||||
const Layout& layout);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR3(
|
||||
std::initializer_list<
|
||||
std::initializer_list<std::initializer_list<NativeT>>>
|
||||
values);
|
||||
static Literal CreateR3(std::initializer_list<
|
||||
std::initializer_list<std::initializer_list<NativeT>>>
|
||||
values);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR3WithLayout(
|
||||
static Literal CreateR3WithLayout(
|
||||
std::initializer_list<
|
||||
std::initializer_list<std::initializer_list<NativeT>>>
|
||||
values,
|
||||
const Layout& layout);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR4(
|
||||
static Literal CreateR4(
|
||||
std::initializer_list<std::initializer_list<
|
||||
std::initializer_list<std::initializer_list<NativeT>>>>
|
||||
values);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR4WithLayout(
|
||||
static Literal CreateR4WithLayout(
|
||||
std::initializer_list<std::initializer_list<
|
||||
std::initializer_list<std::initializer_list<NativeT>>>>
|
||||
values,
|
||||
@ -139,9 +137,10 @@ class LiteralUtil {
|
||||
// [9, 10, 11]: 4.0
|
||||
//
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateSparse(
|
||||
absl::Span<const int64> dimensions, SparseIndexArray indices,
|
||||
absl::Span<const NativeT> values, bool sort = true);
|
||||
static Literal CreateSparse(absl::Span<const int64> dimensions,
|
||||
SparseIndexArray indices,
|
||||
absl::Span<const NativeT> values,
|
||||
bool sort = true);
|
||||
|
||||
// Creates a scalar literal value zero of the given primitive type.
|
||||
static Literal Zero(PrimitiveType primitive_type);
|
||||
@ -155,130 +154,120 @@ class LiteralUtil {
|
||||
static Literal MaxValue(PrimitiveType primitive_type);
|
||||
// Creates a literal of the given shape where each element is `value`.
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
|
||||
static Literal CreateFullWithDescendingLayout(
|
||||
absl::Span<const int64> dimensions, NativeT value);
|
||||
|
||||
// Creates a new literal from an Array type. The variants not ending with
|
||||
// WithLayout use the default XLA layout for the literal's linear
|
||||
// representation in memory.
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
|
||||
static Literal CreateFromArray(const Array<NativeT>& values);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateFromArrayWithLayout(
|
||||
const Array<NativeT>& values, const Layout& layout);
|
||||
static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
|
||||
const Layout& layout);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR2FromArray2D(
|
||||
const Array2D<NativeT>& values);
|
||||
static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
|
||||
const Array2D<NativeT>& values, const Layout& layout);
|
||||
static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
|
||||
const Layout& layout);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR3FromArray3D(
|
||||
const Array3D<NativeT>& values);
|
||||
static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
|
||||
const Array3D<NativeT>& values, const Layout& layout);
|
||||
static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
|
||||
const Layout& layout);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR4FromArray4D(
|
||||
const Array4D<NativeT>& values);
|
||||
static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
|
||||
const Array4D<NativeT>& values, const Layout& layout);
|
||||
static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
|
||||
const Layout& layout);
|
||||
|
||||
// Creates a new vector of U8s literal value from a string.
|
||||
static std::unique_ptr<Literal> CreateR1U8(absl::string_view value);
|
||||
static Literal CreateR1U8(absl::string_view value);
|
||||
|
||||
// Creates a linspace-populated literal with the given number of rows and
|
||||
// columns.
|
||||
static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
|
||||
int64 rows, int64 cols);
|
||||
static Literal CreateR2F32Linspace(float from, float to, int64 rows,
|
||||
int64 cols);
|
||||
|
||||
// Creates a literal that projects the (x, y) dimensions given in values into
|
||||
// the z dimension given by "projection".
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR3Projected(
|
||||
static Literal CreateR3Projected(
|
||||
std::initializer_list<std::initializer_list<NativeT>> values,
|
||||
int64 projection);
|
||||
|
||||
// Creates a literal that projects the (x, y) dimensions given in values into
|
||||
// the z and p dimensions given.
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateR4Projected(
|
||||
static Literal CreateR4Projected(
|
||||
std::initializer_list<std::initializer_list<NativeT>> values,
|
||||
int64 projection_p, int64 projection_z);
|
||||
|
||||
// Returns an identity matrix (rank 2) with the given row and column count.
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
|
||||
static Literal MakeIdentityR2(int64 size);
|
||||
|
||||
// Returns a tuple literal composed of given literals. Data is copied from the
|
||||
// given elements into the returned literal.
|
||||
static std::unique_ptr<Literal> MakeTuple(
|
||||
absl::Span<const Literal* const> elements);
|
||||
static Literal MakeTuple(absl::Span<const Literal* const> elements);
|
||||
|
||||
static std::unique_ptr<Literal> MakeTupleFromSlices(
|
||||
absl::Span<const LiteralSlice> elements);
|
||||
static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
|
||||
|
||||
// As above, but intended to be invoked with move semantics; i.e.
|
||||
//
|
||||
// std::vector<std::unique_ptr<Literal>> elements = ...;
|
||||
// std::vector<Literal> elements = ...;
|
||||
// auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
|
||||
//
|
||||
// This would have been declared as an overload, but there is ambiguity
|
||||
// in invocation between the above signature and this one.
|
||||
static std::unique_ptr<Literal> MakeTupleOwned(
|
||||
std::vector<std::unique_ptr<Literal>> elements);
|
||||
static Literal MakeTupleOwned(std::vector<Literal> elements);
|
||||
|
||||
// This overload lets you pass a braced list of unique_ptr<Literal>s to
|
||||
// This overload lets you pass a braced list of Literals to
|
||||
// MakeTupleOwned:
|
||||
//
|
||||
// LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
|
||||
//
|
||||
// Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
|
||||
// Simply relying on the MakeTupleOwned(std::vector<Literal>)
|
||||
// overload doesn't work because std::initializer_list's elements are always
|
||||
// const.
|
||||
//
|
||||
// The arguments to this function must all be unique_ptr<Literal>.
|
||||
// The arguments to this function must all be Literal.
|
||||
template <typename... Ts>
|
||||
static std::unique_ptr<Literal> MakeTupleOwned(
|
||||
std::unique_ptr<Ts>... elements) {
|
||||
std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{
|
||||
std::move(elements)...};
|
||||
std::vector<std::unique_ptr<Literal>> v;
|
||||
static Literal MakeTupleOwned(Ts... elements) {
|
||||
std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
|
||||
std::vector<Literal> v;
|
||||
v.insert(v.begin(), std::make_move_iterator(arr.begin()),
|
||||
std::make_move_iterator(arr.end()));
|
||||
return MakeTupleOwned(std::move(v));
|
||||
}
|
||||
|
||||
// Create a constant token literal. Token types have no value.
|
||||
static std::unique_ptr<Literal> CreateToken();
|
||||
static Literal CreateToken();
|
||||
|
||||
// Creates a new Literal object with its values having the primitive_type
|
||||
// type, and with dimensions defined by the dimensions parameter.
|
||||
// The content of the literal values is the default value of the primitive
|
||||
// type of literal itself (0 for numeric types, and false for predicates).
|
||||
static std::unique_ptr<Literal> CreateFromDimensions(
|
||||
PrimitiveType primitive_type, absl::Span<const int64> dimensions);
|
||||
static Literal CreateFromDimensions(PrimitiveType primitive_type,
|
||||
absl::Span<const int64> dimensions);
|
||||
|
||||
// If the given literal's data type is bfloat16, converts it to a float
|
||||
// literal; otherwise, returns a copy of it. If the literal is a tuple,
|
||||
// recursively converts its elements.
|
||||
static std::unique_ptr<Literal> ConvertBF16ToF32(
|
||||
const LiteralSlice& bf16_literal);
|
||||
static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
|
||||
|
||||
// If the given literal's data type is float, converts it to a bfloat16
|
||||
// literal; otherwise, returns a copy of it. If the literal is a tuple,
|
||||
// recursively converts its elements.
|
||||
static std::unique_ptr<Literal> ConvertF32ToBF16(
|
||||
const LiteralSlice& f32_literal);
|
||||
static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
|
||||
|
||||
// Creates a literal with a new shape with the given new dimensions using the
|
||||
// data in the given input literal. For reshaping purposes the (flat) data
|
||||
// buffer of the input literal is assumed to have the given minor_to_major
|
||||
// layout order.
|
||||
static std::unique_ptr<Literal> ReshapeSlice(
|
||||
absl::Span<const int64> new_dimensions,
|
||||
absl::Span<const int64> minor_to_major, const LiteralSlice& literal);
|
||||
static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
|
||||
absl::Span<const int64> minor_to_major,
|
||||
const LiteralSlice& literal);
|
||||
|
||||
// Creates a literal with the supplied shape, and uses the provided value
|
||||
// generator to populate the literal's values.
|
||||
@ -286,7 +275,7 @@ class LiteralUtil {
|
||||
template <
|
||||
PrimitiveType type,
|
||||
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
|
||||
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
|
||||
static StatusOr<Literal> CreateRandomLiteral(
|
||||
const Shape& shape,
|
||||
const std::function<T(absl::Span<const int64>)>& generator);
|
||||
|
||||
@ -297,8 +286,8 @@ class LiteralUtil {
|
||||
template <
|
||||
PrimitiveType type, typename E,
|
||||
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
|
||||
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
|
||||
const Shape& shape, E* engine, T mean, T stddev);
|
||||
static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
|
||||
T mean, T stddev);
|
||||
|
||||
// Creates a literal with the supplied shape, and initializes the literal
|
||||
// values using a normal distribution with given mean and stddev standard
|
||||
@ -307,8 +296,8 @@ class LiteralUtil {
|
||||
template <
|
||||
PrimitiveType type,
|
||||
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
|
||||
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
|
||||
const Shape& shape, T mean, T stddev);
|
||||
static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
|
||||
T stddev);
|
||||
|
||||
//
|
||||
// End of factory methods.
|
||||
@ -322,44 +311,43 @@ class LiteralUtil {
|
||||
std::ostream& operator<<(std::ostream& out, const Literal& literal);
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
|
||||
auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShape(
|
||||
/* static */ Literal LiteralUtil::CreateR0(NativeT value) {
|
||||
Literal literal(ShapeUtil::MakeShape(
|
||||
primitive_util::NativeToPrimitiveType<NativeT>(), {}));
|
||||
literal->Set({}, value);
|
||||
literal.Set({}, value);
|
||||
return literal;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
|
||||
absl::Span<const NativeT> values) {
|
||||
auto literal = absl::make_unique<Literal>(
|
||||
/* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
|
||||
Literal literal(
|
||||
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
|
||||
{static_cast<int64>(values.size())}));
|
||||
literal->PopulateR1(values);
|
||||
literal.PopulateR1(values);
|
||||
return literal;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
|
||||
/* static */ Literal LiteralUtil::CreateR2WithLayout(
|
||||
std::initializer_list<std::initializer_list<NativeT>> values,
|
||||
const Layout& layout) {
|
||||
auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
|
||||
Literal literal(ShapeUtil::MakeShapeWithLayout(
|
||||
primitive_util::NativeToPrimitiveType<NativeT>(),
|
||||
{static_cast<int64>(values.size()),
|
||||
static_cast<int64>(values.begin()->size())},
|
||||
AsInt64Slice(layout.minor_to_major())));
|
||||
literal->PopulateR2(values);
|
||||
literal.PopulateR2(values);
|
||||
return literal;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2(
|
||||
/* static */ Literal LiteralUtil::CreateR2(
|
||||
std::initializer_list<std::initializer_list<NativeT>> values) {
|
||||
return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout(
|
||||
/* static */ Literal LiteralUtil::CreateR3WithLayout(
|
||||
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
|
||||
values,
|
||||
const Layout& layout) {
|
||||
@ -384,14 +372,14 @@ template <typename NativeT>
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3(
|
||||
/* static */ Literal LiteralUtil::CreateR3(
|
||||
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
|
||||
values) {
|
||||
return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout(
|
||||
/* static */ Literal LiteralUtil::CreateR4WithLayout(
|
||||
std::initializer_list<std::initializer_list<
|
||||
std::initializer_list<std::initializer_list<NativeT>>>>
|
||||
values,
|
||||
@ -422,23 +410,22 @@ template <typename NativeT>
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
|
||||
/* static */ Literal LiteralUtil::CreateSparse(
|
||||
absl::Span<const int64> dimensions, SparseIndexArray indices,
|
||||
absl::Span<const NativeT> values, bool sort) {
|
||||
int64 num_elements = values.size();
|
||||
int64 rank = dimensions.size();
|
||||
CHECK_EQ(num_elements, indices.index_count());
|
||||
CHECK_EQ(rank, indices.rank());
|
||||
auto literal =
|
||||
absl::make_unique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
|
||||
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
|
||||
indices.max_indices()));
|
||||
literal->PopulateSparse(indices, values, sort);
|
||||
Literal literal(ShapeUtil::MakeShapeWithSparseLayout(
|
||||
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
|
||||
indices.max_indices()));
|
||||
literal.PopulateSparse(indices, values, sort);
|
||||
return literal;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4(
|
||||
/* static */ Literal LiteralUtil::CreateR4(
|
||||
std::initializer_list<std::initializer_list<
|
||||
std::initializer_list<std::initializer_list<NativeT>>>>
|
||||
values) {
|
||||
@ -446,50 +433,48 @@ template <typename NativeT>
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout(
|
||||
/* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
|
||||
const Array<NativeT>& values, const Layout& layout) {
|
||||
auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
|
||||
Literal literal(ShapeUtil::MakeShapeWithLayout(
|
||||
primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
|
||||
AsInt64Slice(layout.minor_to_major())));
|
||||
literal->PopulateFromArray(values);
|
||||
literal.PopulateFromArray(values);
|
||||
return literal;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArray(
|
||||
/* static */ Literal LiteralUtil::CreateFromArray(
|
||||
const Array<NativeT>& values) {
|
||||
return CreateFromArrayWithLayout(
|
||||
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal>
|
||||
LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
|
||||
const Layout& layout) {
|
||||
/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
|
||||
const Array2D<NativeT>& values, const Layout& layout) {
|
||||
return CreateFromArrayWithLayout(values, layout);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D(
|
||||
/* static */ Literal LiteralUtil::CreateR2FromArray2D(
|
||||
const Array2D<NativeT>& values) {
|
||||
return CreateFromArray(values);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal>
|
||||
LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
|
||||
const Layout& layout) {
|
||||
/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
|
||||
const Array3D<NativeT>& values, const Layout& layout) {
|
||||
return CreateFromArrayWithLayout(values, layout);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D(
|
||||
/* static */ Literal LiteralUtil::CreateR3FromArray3D(
|
||||
const Array3D<NativeT>& values) {
|
||||
return CreateFromArray(values);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected(
|
||||
/* static */ Literal LiteralUtil::CreateR3Projected(
|
||||
std::initializer_list<std::initializer_list<NativeT>> values,
|
||||
int64 projection) {
|
||||
int64 dim0_size = projection;
|
||||
@ -514,7 +499,7 @@ template <typename NativeT>
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected(
|
||||
/* static */ Literal LiteralUtil::CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z) {
int64 dim0_size = projection_p;
@@ -542,21 +527,20 @@ template <typename NativeT>
}

template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D(
/* static */ Literal LiteralUtil::CreateR4FromArray4D(
const Array4D<NativeT>& values) {
return CreateFromArray(values);
}

template <typename NativeT>
/* static */ std::unique_ptr<Literal>
LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
const Layout& layout) {
/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}

// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) {
/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
Array2D<NativeT> array(size, size, 0);
for (int64 i = 0; i < size; ++i) {
array(i, i) = 1;
@@ -565,33 +549,29 @@ template <typename NativeT>
}

template <typename NativeT>
/* static */ std::unique_ptr<Literal>
LiteralUtil::CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,
NativeT value) {
auto literal =
absl::make_unique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
literal->PopulateWithValue(value);
/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
absl::Span<const int64> dimensions, NativeT value) {
Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
literal.PopulateWithValue(value);
return literal;
}

template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralUtil::CreateRandomLiteral(
/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
const Shape& shape,
const std::function<T(absl::Span<const int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
auto literal = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
Literal literal(shape);
TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
[&](absl::Span<const int64> indexes) { return generator(indexes); }));
return std::move(literal);
}

template <PrimitiveType type, typename E, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
T stddev) {
/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
const Shape& shape, E* engine, T mean, T stddev) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
@@ -600,8 +580,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
}

template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
const Shape& shape, T mean, T stddev) {
std::minstd_rand0 engine;
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}
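The templated LiteralUtil factories above now hand back a Literal by value instead of a std::unique_ptr<Literal>. A minimal caller-side sketch of the new style, using only factory and accessor names visible in this diff (variable names are illustrative):

  // Factories return a Literal value; there is no unique_ptr and no operator->.
  xla::Literal scalar = xla::LiteralUtil::CreateR0<float>(7.3f);
  xla::Literal matrix = xla::LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
  // Members are accessed directly on the value.
  const xla::Shape& shape = matrix.shape();
  // Ownership is transferred by moving, e.g. into a constant HLO instruction.
  auto constant = xla::HloInstruction::CreateConstant(std::move(matrix));

Moving the literal leaves it in a moved-from state, which is why the test call sites later in this commit read the shape before the std::move.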

@@ -39,8 +39,8 @@ PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file)

PackedLiteralReader::~PackedLiteralReader() { delete file_; }

StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
const Shape& shape, const Layout* layout) {
StatusOr<Literal> PackedLiteralReader::Read(const Shape& shape,
const Layout* layout) {
VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
<< " layout: "
<< (layout == nullptr ? "<none>" : layout->ShortDebugString());
@@ -57,11 +57,11 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
PrimitiveType_Name(shape.element_type()));
}

auto result = absl::make_unique<Literal>(literal_shape);
result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
Literal result(literal_shape);
result.PopulateWithValue(std::numeric_limits<float>::quiet_NaN());

int64 elements = ShapeUtil::ElementsIn(shape);
absl::Span<const float> field = result->data<float>();
absl::Span<const float> field = result.data<float>();
char* data = absl::bit_cast<char*>(field.data());
uint64 bytes = elements * sizeof(float);
absl::string_view sp;

@@ -41,8 +41,7 @@ class PackedLiteralReader {
//
// Layout is optional. If it is not provided, no layout is set on the literal
// that is produced.
StatusOr<std::unique_ptr<Literal>> Read(const Shape& shape,
const Layout* layout = nullptr);
StatusOr<Literal> Read(const Shape& shape, const Layout* layout = nullptr);

// Returns whether the input file has been fully exhausted; i.e. all available
// packed literals have been read and we're at the end of the file.

@@ -81,8 +81,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal,
return client->TransferToInfeedLocal(literal, device_ordinal);
}

StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocalReplica(
const Shape& shape, int replica_number) {
StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
int replica_number) {
VLOG(1) << "Outfeeding literal from replica number: " << replica_number
<< " shape: " << shape;
LocalClient* client = GetOrCreateLocalClient();
@@ -141,9 +141,8 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
LocalClient* client = GetOrCreateLocalClient();
StatusOr<ScopedShapedBuffer> buf = [&] {
if (shape_with_layout) {
std::unique_ptr<Literal> relaid =
argument.Relayout(shape_with_layout.value());
return ToBuffer(client, /*device_ordinal=*/0, *relaid);
Literal relaid = argument.Relayout(shape_with_layout.value());
return ToBuffer(client, /*device_ordinal=*/0, relaid);
}
return ToBuffer(client, /*device_ordinal=*/0, argument);
}();
@@ -151,7 +150,7 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
return new LocalShapedBuffer(std::move(buf).ValueOrDie());
}

StatusOr<std::unique_ptr<Literal>> LocalShapedBuffer::ToLiteral() const {
StatusOr<Literal> LocalShapedBuffer::ToLiteral() const {
LocalClient* client = GetOrCreateLocalClient();
return client->ShapedBufferToLiteral(*shaped_buffer());
}
@@ -160,7 +159,7 @@ CompiledLocalComputation::CompiledLocalComputation(
std::unique_ptr<LocalExecutable> executable)
: executable_(std::move(executable)) {}

StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
StatusOr<Literal> CompiledLocalComputation::Execute(
const std::vector<Literal>& arguments,
const std::vector<absl::optional<Shape>>& shapes_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
@@ -169,7 +168,7 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(

// Each replica populates a StatusOr result, but only replica zero actually
// retrieves its literal value.
std::vector<StatusOr<std::unique_ptr<Literal>>> results(GetReplicaCount());
std::vector<StatusOr<Literal>> results(GetReplicaCount());
{
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
GetReplicaCount());
@@ -198,9 +197,8 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(

StatusOr<ScopedShapedBuffer> pushed;
if (shape_with_layout) {
std::unique_ptr<Literal> relaid =
argument.Relayout(shape_with_layout.value());
pushed = ToBuffer(client, device_ordinal, *relaid);
Literal relaid = argument.Relayout(shape_with_layout.value());
pushed = ToBuffer(client, device_ordinal, relaid);
} else {
pushed = ToBuffer(client, device_ordinal, argument);
}

@@ -51,8 +51,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number);
// Transfers a literal of the given shape from the outfeed of the given replica.
//
// The replica number is resolved to an appropriate device ordinal.
StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
const Shape& shape, int replica_number);
StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
int replica_number);

// Wraps a ScopedShapedBuffer produced by copying a literal "to
// device," i.e. copying a literal to a scoped buffer via the local
@@ -65,7 +65,7 @@ class LocalShapedBuffer {
LocalShapedBuffer(ScopedShapedBuffer shaped_buffer);
const ScopedShapedBuffer* shaped_buffer() const;

StatusOr<std::unique_ptr<Literal> > ToLiteral() const;
StatusOr<Literal> ToLiteral() const;

// Transfers ownership of the encapsulated ShapedBuffer to the caller,
// analogous to std::unique_ptr::release().
@@ -117,7 +117,7 @@ class CompiledLocalComputation {
// with optionally-specified argument layouts. The literals will be
// re-laid out according to the corresponding elements of
// shapes_with_layout.
StatusOr<std::unique_ptr<Literal> > Execute(
StatusOr<Literal> Execute(
const std::vector<Literal>& arguments,
const std::vector<absl::optional<Shape> >& shapes_with_layout);
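Everything the Python client surface returns is now a StatusOr<Literal> rather than a StatusOr<std::unique_ptr<Literal>>. A rough consumption sketch under that signature (the local_shaped_buffer pointer and the logging line are placeholders, not part of this change):

  StatusOr<Literal> result_or = local_shaped_buffer->ToLiteral();
  if (!result_or.ok()) {
    return result_or.status();
  }
  // ConsumeValueOrDie() now yields a Literal value, not a pointer to one.
  Literal result = result_or.ConsumeValueOrDie();
  VLOG(1) << "result shape: " << ShapeUtil::HumanString(result.shape());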

@@ -216,9 +216,9 @@ tensorflow::ImportNumpy();
}

%typemap(out) StatusOr< std::unique_ptr<Literal> > {
%typemap(out) StatusOr<Literal> {
if ($1.ok()) {
std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
Literal value = $1.ConsumeValueOrDie();
$result = numpy::PyObjectFromXlaLiteral(*value);
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@@ -346,25 +346,25 @@ tensorflow::ImportNumpy();

// Literal

%typemap(in) const Literal& (StatusOr< std::unique_ptr<Literal> > literal_status) {
%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
literal_status = numpy::XlaLiteralFromPyObject($input);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
SWIG_fail;
}
$1 = literal_status.ValueOrDie().get();
$1 = &literal_status.ValueOrDie();
}

%typemap(out) std::unique_ptr<Literal> {
%typemap(out) Literal {
$result = numpy::PyObjectFromXlaLiteral(*$1);
}

%typemap(out) StatusOr< std::unique_ptr<Literal> > {
%typemap(out) StatusOr<Literal> {
if (!$1.ok()) {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail;
}
$result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
$result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
}

%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
@@ -375,13 +375,13 @@ tensorflow::ImportNumpy();
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
StatusOr< std::unique_ptr<Literal> > literal_status = numpy::XlaLiteralFromPyObject(o);
StatusOr<Literal> literal_status = numpy::XlaLiteralFromPyObject(o);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
Py_DECREF(o);
SWIG_fail;
}
temps.push_back(std::move(*literal_status.ConsumeValueOrDie()));
temps.push_back(literal_status.ConsumeValueOrDie());
Py_DECREF(o);
}
$1 = &temps;

@@ -368,10 +368,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
}
}

StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o) {
if (PyTuple_Check(o)) {
int num_elements = PyTuple_Size(o);
std::vector<std::unique_ptr<Literal>> elements;
std::vector<Literal> elements;
elements.reserve(num_elements);
for (int i = 0; i < num_elements; i++) {
PyObject* element = PyTuple_GetItem(o, i);
@@ -389,8 +389,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
int np_type = PyArray_TYPE(py_array);
auto literal = LiteralUtil::CreateFromDimensions(
NumpyTypeToPrimitiveType(np_type), dimensions);
TF_RETURN_IF_ERROR(
CopyNumpyArrayToLiteral(np_type, py_array, literal.get()));
TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal));
return std::move(literal);
} else {
return InvalidArgument(

@@ -82,7 +82,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
// To avoid transferring ownership of the data buffers that underlie
// PyArrays and XLA literals, this function makes deep copies of all
// array data.
StatusOr<std::unique_ptr<Literal> > XlaLiteralFromPyObject(PyObject* o);
StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);

// The following functions copy array data from the buffers underlying Numpy
// ndarrays into those underlying XLA literals, and vice versa.

@@ -529,13 +529,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
}

ordered_input_dimensions[0] =
lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0));
lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
ordered_input_dimensions[1] =
lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1));
lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
ordered_kernel_dimensions[0] =
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
ordered_kernel_dimensions[1] =
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));

std::vector<std::pair<int64, int64>> paddings =
MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
@@ -546,7 +546,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(

WindowDimension dim;
dim.set_size(
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
dim.set_stride(kernel_stride.first);
dim.set_padding_low(paddings[0].first);
dim.set_padding_high(paddings[0].second);
@@ -556,7 +556,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(

WindowDimension dim2;
dim2.set_size(
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
dim2.set_stride(kernel_stride.second);
dim2.set_padding_low(paddings[1].first);
dim2.set_padding_high(paddings[1].second);
@@ -565,7 +565,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
*window.add_dimensions() = dim2;

const Shape& shape = ShapeInference::InferConvolveShape(
lhs_literal->shape(), rhs_literal->shape(),
lhs_literal.shape(), rhs_literal.shape(),
/*feature_group_count=*/1, window, dnums)
.ConsumeValueOrDie();

@@ -585,18 +585,18 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
auto computation = module.AddEntryComputation(b.Build());

HloEvaluator evaluator;
std::unique_ptr<Literal> result_literal =
Literal result_literal =
evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();

CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4);
auto result =
absl::make_unique<Array4D<float>>(result_literal->shape().dimensions(0),
result_literal->shape().dimensions(1),
result_literal->shape().dimensions(2),
result_literal->shape().dimensions(3));
absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
result_literal.shape().dimensions(1),
result_literal.shape().dimensions(2),
result_literal.shape().dimensions(3));

result->Each([&](absl::Span<const int64> indices, float* value) {
*value = result_literal->Get<float>(indices);
*value = result_literal.Get<float>(indices);
});

return result;

@@ -55,7 +55,7 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) {
auto result = ReferenceUtil::TransposeArray2D(*matrix_);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
*actual_literal, ErrorSpec(0.0001));
actual_literal, ErrorSpec(0.0001));
}

TEST_F(ReferenceUtilTest, MatmulArray2D) {
@@ -67,14 +67,14 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
*actual_literal, ErrorSpec(0.0001));
actual_literal, ErrorSpec(0.0001));
}

TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, actual_literal,
ErrorSpec(0.0001));
}

@@ -82,7 +82,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, actual_literal,
ErrorSpec(0.0001));
}

@@ -90,14 +90,14 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
auto result = LiteralUtil::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
Array4D<float>(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2},
[](float a, float b) { return a + b; }));
LiteralTestUtil::ExpectR1Equal<float>({0}, *result);
LiteralTestUtil::ExpectR1Equal<float>({0}, result);
}

TEST_F(ReferenceUtilTest, MapArray2D) {
auto identity = [](float value) { return log(exp(value)); };
auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal,
ErrorSpec(0.0001));
}

@@ -108,7 +108,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
*actual_literal, ErrorSpec(0.0001));
actual_literal, ErrorSpec(0.0001));
}

TEST_F(ReferenceUtilTest, MapArray4D) {
@@ -121,7 +121,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) {

Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.FillWithMultiples(2.0f);
LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
ErrorSpec(0.0001));
}

@@ -138,7 +138,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {

Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.Fill(0.0f);
LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
ErrorSpec(0.0001));
}

@@ -146,16 +146,16 @@ TEST_F(ReferenceUtilTest, SliceArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}});
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);

LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}},
*actual_literal, ErrorSpec(0.0001));
LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}}, actual_literal,
ErrorSpec(0.0001));
}

TEST_F(ReferenceUtilTest, SliceStridedArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}});
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);

LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}},
*actual_literal, ErrorSpec(0.0001));
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}}, actual_literal,
ErrorSpec(0.0001));
}

TEST_F(ReferenceUtilTest, SliceArray3D) {
@@ -167,7 +167,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);

LiteralTestUtil::ExpectR3Near<float>(
{{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal,
{{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal,
ErrorSpec(0.0001));
}

@@ -180,8 +180,8 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);

LiteralTestUtil::ExpectR3Near<float>(
{{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}},
*actual_literal, ErrorSpec(0.0001));
{{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal,
ErrorSpec(0.0001));
}

TEST_F(ReferenceUtilTest, SliceArray4D) {
@@ -194,7 +194,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) {

LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}},
*actual_literal, ErrorSpec(0.0001));
actual_literal, ErrorSpec(0.0001));
}

TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
@@ -208,7 +208,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}},
{{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}},
*actual_literal, ErrorSpec(0.0001));
actual_literal, ErrorSpec(0.0001));
}

TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
@@ -220,7 +220,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {

auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);

LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}

@@ -233,7 +233,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) {

auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);

LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}

@@ -268,7 +268,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) {

auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);

LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}

@@ -302,7 +302,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) {

auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);

LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}

@@ -358,7 +358,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {

auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);

LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}

@@ -411,7 +411,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {

auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);

LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}

@@ -424,7 +424,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) {
[](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
*actual_literal, ErrorSpec(0.0001));
actual_literal, ErrorSpec(0.0001));
}

} // namespace
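The test churn above is mechanical: LiteralTestUtil expectation helpers now take the literal itself (by const reference) instead of a dereferenced unique_ptr, and the literal locals are plain values. A condensed before/after sketch with the values from one of these tests:

  // Before: auto actual_literal = LiteralUtil::CreateR1<float>(*result);   // unique_ptr
  //         LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal, ErrorSpec(0.0001));
  // After: the same factory returns a Literal value and is passed directly.
  auto actual_literal = LiteralUtil::CreateR1<float>(*result);
  LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, actual_literal, ErrorSpec(0.0001));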

@@ -95,12 +95,11 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
std::vector<float> expected = {
1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR1<float>(expected);
Literal expected_literal = LiteralUtil::CreateR1<float>(expected);
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
computation, {}, nullptr));
EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal,
ErrorSpec(0.0001)));
}

@@ -205,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
HloInstruction* zero =
computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
@@ -527,7 +527,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation,
return computation->AddInstruction(HloInstruction::CreateTuple(elems));
} else {
return computation->AddInstruction(
HloInstruction::CreateConstant(literal.CloneToUnique()));
HloInstruction::CreateConstant(literal.Clone()));
}
}

@@ -546,7 +546,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
std::unique_ptr<Literal> unique_scalar = absl::make_unique<Literal>(
Literal unique_scalar(
LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
@@ -676,7 +676,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
return Status::OK();
}
auto inverse = computation_->AddInstruction(
HloInstruction::CreateConstant((new_literal.CloneToUnique())));
HloInstruction::CreateConstant((new_literal.Clone())));
TF_ASSIGN_OR_RETURN(auto new_divide,
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
return ReplaceInstruction(divide, new_divide);
@@ -1469,7 +1469,7 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
auto* iota = Cast<HloIotaInstruction>(instruction);
if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique()));
LiteralUtil::Zero(iota->shape().element_type()).Clone()));
return ReplaceWithNewInstruction(
iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
}
@@ -1572,7 +1572,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
LiteralUtil::One(power->shape().element_type()).CloneToUnique());
LiteralUtil::One(power->shape().element_type()).Clone());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
ones = std::move(one);
@@ -1607,7 +1607,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
LiteralUtil::One(rhs->shape().element_type()).Clone()));

// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
@@ -2062,7 +2062,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
if (!converted_pad_literal.ok()) {
return false;
}
return *converted_pad_literal.ValueOrDie() == reduce_init_literal;
return converted_pad_literal.ValueOrDie() == reduce_init_literal;
};
// The pad value is usually a constant, so we handle that case and do not
// try to get more fancy about proving equivalence in cases beyond that.
@@ -2223,8 +2223,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
HloInstruction::CreateBroadcast(
convolution->shape(),
computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(convolution->shape().element_type())
.CloneToUnique())),
LiteralUtil::Zero(convolution->shape().element_type()))),
{}));
}
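Inside the simplifier, HloInstruction::CreateConstant now accepts a Literal by value, so the CloneToUnique()/make_unique wrappers disappear: a scalar from LiteralUtil::Zero/One is passed through Clone() (or, as in later files of this commit, directly), and a literal that must outlive the call is duplicated with Clone(). A sketch of the two patterns, with computation_, shape, and existing_literal as stand-ins for the names in the hunks above:

  // Fresh scalar constant.
  HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(shape.element_type()).Clone()));
  // Re-using a literal the caller keeps: Clone() replaces CloneToUnique().
  HloInstruction* copy = computation_->AddInstruction(
      HloInstruction::CreateConstant(existing_literal.Clone()));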

@@ -2932,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
HloComputation::Builder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
{LiteralUtil::CreateR0<float>(constant_scalar).get(),
LiteralUtil::CreateR1<float>(constant_vector).get()});
Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
LiteralUtil::CreateR1<float>(constant_vector)};
Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));

auto computation = module().AddEntryComputation(builder.Build());

@@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
const Shape feature_shape = scale->shape();

auto zero_literal = LiteralUtil::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = add(HloInstruction::CreateBroadcast(
operand_shape,
add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
@@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
const Shape feature_shape = scale->shape();

auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
operand_shape,
computation_->AddInstruction(
@@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
const int64 elements_per_feature_int64 = size_in_elements / feature_count;

auto zero_literal = LiteralUtil::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon_scalar =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
auto epsilon_activation = add(
@@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
auto elements_per_feature_literal =
LiteralUtil::CreateR0<float>(elements_per_feature_int64);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
elements_per_feature_literal.Convert(ptype));
auto elements_per_feature = add(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,

@@ -163,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)),
LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
dot->operand(0)->literal()));
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)),
LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
dot->operand(1)->literal()));
}

@@ -1245,9 +1245,10 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
// Test that a tuple constant which is forwarded to the computation output
// is properly handled.
auto builder = HloComputation::Builder(TestName());
Literal elements[] = {LiteralUtil::CreateR0<int64>(0),
LiteralUtil::CreateR0<int64>(1)};
builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
LiteralUtil::CreateR0<int64>(1).get()})));
LiteralUtil::MakeTuple({&elements[0], &elements[1]})));

auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());

@@ -440,15 +440,15 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
// the buffers containing {3} and 3 are dead.
auto builder = HloComputation::Builder(TestName());
auto inner_tuple0 =
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
LiteralUtil::CreateR0<int64>(1).get()});
auto inner_tuple1 =
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
LiteralUtil::CreateR0<int64>(1)};
auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
Literal element1 = LiteralUtil::CreateR0<int64>(3);
auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
inner_tuple0->shape(), tuple_constant, 0));
inner_tuple0.shape(), tuple_constant, 0));

auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
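LiteralUtil::MakeTuple keeps its pointer-based signature, but the element literals are now values owned by the caller rather than temporaries held in unique_ptrs. A small sketch of the pattern these tests settle on (element values arbitrary, builder as in the tests above):

  Literal elements[] = {LiteralUtil::CreateR0<int64>(0),
                        LiteralUtil::CreateR0<int64>(1)};
  Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
  // The caller keeps `elements` alive for the call; the resulting tuple is again a plain value.
  builder.AddInstruction(HloInstruction::CreateConstant(std::move(tuple)));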

@@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
expanded_filter = add(HloInstruction::CreateConcatenate(
expanded_filter_shape, concat_operands, input_feature_dim));
}
auto zero = add(HloInstruction::CreateConstant(absl::make_unique<Literal>(
LiteralUtil::Zero(expanded_filter_shape.element_type()))));
auto zero = add(HloInstruction::CreateConstant(
LiteralUtil::Zero(expanded_filter_shape.element_type())));
auto zero_filter =
add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
auto new_filter = add(

@@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
auto builder = HloComputation::Builder(TestName());
auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
Shape vshape = input_literal1->shape();
Shape vshape = input_literal1.shape();

auto input1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal1)));
@@ -78,13 +78,13 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
auto result = ExecuteAndTransfer(module->Clone(), {});

// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, result, error_spec_);
}

TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
Shape vshape = input_literal->shape();
Shape vshape = input_literal.shape();

auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -125,8 +125,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto result = ExecuteAndTransfer(module->Clone(), {});

// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
error_spec_);
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, result, error_spec_);
}

TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
@@ -135,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
Shape vshape = input_literal->shape();
Shape vshape = input_literal.shape();

auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -213,7 +212,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {

// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
*result, error_spec_);
result, error_spec_);
}

TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
@@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// each fusion instruction to ensure that negate is not duplicated.
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
Shape vshape = input_literal->shape();
Shape vshape = input_literal.shape();

auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));

@@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase {
};

TEST_F(InfeedTest, SingleInfeedR0Bool) {
TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}

TEST_F(InfeedTest, SingleInfeedR1U32) {
TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}

TEST_F(InfeedTest, SingleInfeedR2F32) {
TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}

TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
*LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}

TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});

TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));

TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}

TEST_F(InfeedTest, SingleInfeedR4S32) {
TestInfeedRoundTrip(*LiteralUtil::CreateR4(
TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}

TEST_F(InfeedTest, SingleInfeedTuple) {
TestInfeedRoundTrip(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
LiteralUtil::CreateR0<bool>(false).get()}));
TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<uint32>({1, 2, 3}),
LiteralUtil::CreateR0<bool>(false)}));
}

TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}

// Tests Infeed operation used in a while loop, as in the code below. The
@@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {

// Send 5 Infeed data of shape F32[3].
ASSERT_IS_OK(
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
client_->TransferToInfeed(LiteralUtil::CreateR1<float>({1, 2, 3})));
ASSERT_IS_OK(
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
client_->TransferToInfeed(LiteralUtil::CreateR1<float>({4, 5, 6})));
ASSERT_IS_OK(
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
client_->TransferToInfeed(LiteralUtil::CreateR1<float>({7, 8, 9})));
ASSERT_IS_OK(
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
client_->TransferToInfeed(LiteralUtil::CreateR1<float>({10, 11, 12})));
ASSERT_IS_OK(
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
client_->TransferToInfeed(LiteralUtil::CreateR1<float>({13, 14, 15})));

delete computation_thread;  // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();

// Only the first 3 infeed data should be added.
LiteralTestUtil::ExpectR0Near<float>(45.0f, *result_literal, ErrorSpec{1e-7});
LiteralTestUtil::ExpectR0Near<float>(45.0f, result_literal, ErrorSpec{1e-7});
}

// Tests two Infeed operations with a total order. The order is enforced by
@@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {

// Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
LiteralUtil::CreateR0<bool>(true).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
LiteralUtil::CreateR0<bool>(true).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({3, 4}),
LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
LiteralUtil::CreateR0<bool>(true).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({5, 6}),
LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
LiteralUtil::CreateR0<bool>(false).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8}),
LiteralUtil::CreateR0<bool>(false)})));

// Asynchronously launch the execution on the device.
std::unique_ptr<GlobalData> result;
@@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
sleep(1);
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
LiteralUtil::CreateR0<bool>(true).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2, 3}),
LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
LiteralUtil::CreateR0<bool>(false).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8, 9}),
LiteralUtil::CreateR0<bool>(false)})));
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
LiteralUtil::CreateR0<bool>(true).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({4, 5, 6}),
LiteralUtil::CreateR0<bool>(true)})));

// Wait for the execution to be done, and transfer the result.
delete computation_thread;  // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();

// Only the first 6 infeed data should be added.
LiteralTestUtil::ExpectR0Near<float>(66.0f, *result_literal, ErrorSpec{1e-7});
LiteralTestUtil::ExpectR0Near<float>(66.0f, result_literal, ErrorSpec{1e-7});
}

} // namespace
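When a tuple is built inline from temporaries, the tests switch to LiteralUtil::MakeTupleFromSlices, which accepts the element literals themselves and so composes in one expression. A sketch of the infeed usage above (client_ is the test fixture's client handle, as in these tests):

  // Builds a (F32[2], PRED) tuple literal from temporaries in a single expression.
  Literal tuple = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({1, 2}), LiteralUtil::CreateR0<bool>(true)});
  ASSERT_IS_OK(client_->TransferToInfeed(tuple));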

@@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {};
TEST_F(CpuNoAliasTest, Concat) {
HloComputation::Builder builder(TestName());

std::unique_ptr<Literal> literal =
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
HloInstruction* param_x = builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "x"));

@@ -56,9 +56,9 @@ ENTRY main {
}
)";

std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
RunTest(hlo_text, {lhs.get(), rhs.get()});
Literal lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
Literal rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
RunTest(hlo_text, {&lhs, &rhs});
}
} // namespace
} // namespace xla

@@ -125,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
device_memory.size());
// Element is array-shaped: transfer array data to device buffer.
const auto subliteral = LiteralSlice(literal, index);
std::unique_ptr<Literal> relayed_out_literal;
Literal relayed_out_literal;
const void* source;
if (LayoutUtil::Equal(device_subshape.layout(),
subliteral.shape().layout())) {
@@ -138,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
// Relayout data before transferring.
relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
/*shape_index=*/{});
source = relayed_out_literal->untyped_data();
source = relayed_out_literal.untyped_data();
TF_RETURN_IF_ERROR(TransferBufferToDevice(
stream,
/*size=*/GetByteSizeRequirement(device_subshape), source,

@@ -590,7 +590,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) {
Array4D<float> constant_arr(4, 4, 2, 2);
constant_arr.FillIota(0);
string constant_str =
LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString();
LiteralUtil::CreateR4FromArray4D(constant_arr).ToString();
ParseAndVerifyModule(absl::StrFormat(R"(
HloModule test

@@ -23,7 +23,6 @@ limitations under the License.
namespace xla {
namespace gpu {

// We want the input/output feature counts of an f16 conv to be factors of 8,
// because without this cudnn can't use tensor cores on the conv.
static constexpr int64 kDesiredNumFeaturesFactor = 8;
@@ -63,8 +62,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
HloComputation* comp = instr->parent();

const Shape& shape = instr->shape();
auto* zero = comp->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
auto* zero = comp->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));

PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));

@@ -68,9 +68,8 @@ HloInstruction* MaybePaddedAndSlicedInput(
conv_window.dimensions(i).base_dilation() - 1);
}
PrimitiveType element_type = input->shape().element_type();
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
HloInstruction* padding = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
input = MakePadHlo(input, padding, padding_config).ValueOrDie();
}

@@ -125,9 +124,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,

HloComputation* computation = kernel->parent();
PrimitiveType element_type = kernel->shape().element_type();
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
HloInstruction* padding = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
@@ -236,9 +234,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// Create a new backward convolution replacing the old one.
HloComputation* computation = backward_conv->parent();
HloInstruction* output = backward_conv->mutable_operand(1);
HloInstruction* padding = computation->AddInstruction(
HloInstruction::CreateConstant(absl::make_unique<Literal>(
LiteralUtil::Zero(input->shape().element_type()))));
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(input->shape().element_type())));
HloInstruction* padded_input =
MakePadHlo(input, padding, input_padding_config).ValueOrDie();

@@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {};
TEST_F(GpuCopyTest, UseMemcpy) {
HloComputation::Builder builder(TestName());

std::unique_ptr<Literal> literal =
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
builder.AddInstruction(HloInstruction::CreateUnary(

@@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase {
};

TEST_F(InfeedTest, SingleInfeedR0Bool) {
TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}

TEST_F(InfeedTest, SingleInfeedR1U32) {
TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}

TEST_F(InfeedTest, SingleInfeedR2F32) {
TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}

TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
*LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}

TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});

TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));

TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}

TEST_F(InfeedTest, SingleInfeedR4S32) {
TestInfeedRoundTrip(*LiteralUtil::CreateR4(
TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
@@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) {
TEST_F(InfeedTest, LargeInfeed) {
Array4D<float> array(80, 100, 8, 128);
array.FillIota(1.0f);
TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D<float>(array));
TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D<float>(array));
}

TEST_F(InfeedTest, SingleInfeedTuple) {
TestInfeedRoundTrip(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
LiteralUtil::CreateR0<bool>(false).get()}));
TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<uint32>({1, 2, 3}),
LiteralUtil::CreateR0<bool>(false)}));
}

TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}

// Tests that a large tuple infeed can be handled.
TEST_F(InfeedTest, SingleInfeedLargeTuple) {
Array4D<float> array(40, 100, 8, 128);
array.FillIota(1.0f);
TestInfeedRoundTrip(*LiteralUtil::MakeTuple(
{LiteralUtil::CreateR4FromArray4D<float>(array).get(),
LiteralUtil::CreateR0<int32>(5).get()}));
TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4FromArray4D<float>(array),
LiteralUtil::CreateR0<int32>(5)}));
}

} // namespace

@@ -76,10 +76,10 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
continue;
}

std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
Literal result;
// Currently we skip unimplemented operations.
// TODO(b/35975797): Fold constant computations for more operations.
if (result == nullptr) {
if (!evaluator->TryEvaluate(instruction, &result)) {
VLOG(2) << "Constant folding failed for instruction: "
<< instruction->ToString();
continue;
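The constant folder shows the new evaluator contract: instead of returning a possibly-null std::unique_ptr<Literal>, HloEvaluator::TryEvaluate fills a caller-owned Literal and reports success as a bool. Roughly, inside the per-instruction folding loop (the final CreateConstant line is only a sketch of what a successful fold feeds into):

  Literal result;
  if (!evaluator->TryEvaluate(instruction, &result)) {
    // Unimplemented operations simply fail to fold and are skipped.
    continue;
  }
  // On success `result` owns the evaluated value and can be moved into a constant HLO.
  auto folded_constant = HloInstruction::CreateConstant(std::move(result));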
|
||||
|
@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto literal,
|
||||
LiteralUtil::CreateRandomLiteral<F32>(
|
||||
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
|
||||
auto literal_clone = literal->Literal::CloneToUnique();
|
||||
auto literal_clone = literal.Clone();
|
||||
HloInstruction* literal_instruction = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(std::move(literal)));
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
|
||||
@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
|
||||
root->literal().EachCell<NativeT>(
|
||||
[&](absl::Span<const int64> indices, NativeT value) {
|
||||
std::vector<int64> rindexes = Permute(permutation, indices);
|
||||
matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
|
||||
matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
|
||||
});
|
||||
EXPECT_TRUE(matched);
|
||||
}
|
||||
|
@ -321,18 +321,17 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
|
||||
padding_config_dim.set_edge_padding_high(zeros_to_append);
|
||||
*padding_config.add_dimensions() = padding_config_dim;
|
||||
|
||||
HloInstruction* zero = computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(absl::make_unique<Literal>(
|
||||
LiteralUtil::Zero(operand->shape().element_type()))));
|
||||
HloInstruction* zero =
|
||||
computation->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(operand->shape().element_type())));
|
||||
return MakePadHlo(operand, zero, padding_config);
|
||||
}

StatusOr<HloInstruction*> BroadcastZeros(
HloComputation* computation, PrimitiveType element_type,
absl::Span<const int64> broadcast_dimensions) {
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
HloInstruction* zero = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}
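
To make the constant-creation change concrete: HloInstruction::CreateConstant now takes the literal by value, so the absl::make_unique<Literal> wrapper disappears. A small sketch of how these helpers might be used from a Status-returning function, assuming a valid HloComputation* computation as in this file (shape and dimensions invented for illustration):

// LiteralUtil::Zero returns a Literal by value; it is moved straight into the
// constant instruction.
HloInstruction* zero = computation->AddInstruction(
    HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
// Broadcast zeros into a 2x2 array; the last argument is forwarded as the
// result shape bounds, as the MakeBroadcastHlo call above shows.
TF_ASSIGN_OR_RETURN(HloInstruction* zeros,
                    BroadcastZeros(computation, F32, {2, 2}));
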

@ -57,10 +57,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
entry_computation->set_root_instruction(first_1_dims_collapsed);

HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({3, 4}));
CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({3, 4}));
}
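
The test rewrite above is the recurring pattern of this cleanup: HloEvaluator::Evaluate<Literal> takes the argument literals by value and yields a Literal, so results are compared directly instead of through pointers. A condensed sketch, assuming a hypothetical module whose entry computation takes two parameters and an expected_literal computed elsewhere:

HloEvaluator evaluator;
// Argument literals go in a braced list, one per parameter in
// parameter-number order; the result comes back as a Literal value.
TF_ASSERT_OK_AND_ASSIGN(
    Literal result,
    evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR1<int32>({3, 4}),
                                          LiteralUtil::CreateR0<int32>(7)}));
// Literals compare by value, so no dereferencing is needed.
CHECK_EQ(result, expected_literal);
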
|
||||
|
||||
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
|
||||
@ -78,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<Literal> result_literal,
|
||||
evaluator.Evaluate<std::unique_ptr<Literal>>(
|
||||
Literal result_literal,
|
||||
evaluator.Evaluate<Literal>(
|
||||
*module,
|
||||
{LiteralUtil::CreateR3<int32>(
|
||||
{{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})}));
|
||||
CHECK_EQ(*result_literal,
|
||||
*LiteralUtil::CreateR2<int32>(
|
||||
CHECK_EQ(result_literal,
|
||||
LiteralUtil::CreateR2<int32>(
|
||||
{{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
|
||||
}
|
||||
|
||||
@ -103,10 +103,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<Literal> result_literal,
|
||||
evaluator.Evaluate<std::unique_ptr<Literal>>(
|
||||
*module, {LiteralUtil::CreateR1<int32>({9, 10})}));
|
||||
CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9, 10}}));
|
||||
Literal result_literal,
|
||||
evaluator.Evaluate<Literal>(*module,
|
||||
{LiteralUtil::CreateR1<int32>({9, 10})}));
|
||||
CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9, 10}}));
|
||||
}
|
||||
|
||||
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
|
||||
@ -124,10 +124,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<Literal> result_literal,
|
||||
evaluator.Evaluate<std::unique_ptr<Literal>>(
|
||||
*module, {LiteralUtil::CreateR1<int32>({9, 10})}));
|
||||
CHECK_EQ(*result_literal, *LiteralUtil::CreateR3<int32>({{{9, 10}}}));
|
||||
Literal result_literal,
|
||||
evaluator.Evaluate<Literal>(*module,
|
||||
{LiteralUtil::CreateR1<int32>({9, 10})}));
|
||||
CHECK_EQ(result_literal, LiteralUtil::CreateR3<int32>({{{9, 10}}}));
|
||||
}
|
||||
|
||||
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
|
||||
@ -144,10 +144,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
|
||||
entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
|
||||
evaluator.Evaluate<std::unique_ptr<Literal>>(
|
||||
*module, {LiteralUtil::CreateR0<int32>(9)}));
|
||||
CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9}}));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Literal result_literal,
|
||||
evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(9)}));
|
||||
CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9}}));
|
||||
}
|
||||
|
||||
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
|
||||
@ -165,11 +165,11 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<Literal> result_literal,
|
||||
evaluator.Evaluate<std::unique_ptr<Literal>>(
|
||||
Literal result_literal,
|
||||
evaluator.Evaluate<Literal>(
|
||||
*module, {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
|
||||
CHECK_EQ(*result_literal,
|
||||
*LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
|
||||
CHECK_EQ(result_literal,
|
||||
LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
|
||||
}
|
||||
|
||||
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
|
||||
@ -187,10 +187,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
|
||||
entry_computation->set_root_instruction(zero_padded_param);
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
|
||||
evaluator.Evaluate<std::unique_ptr<Literal>>(
|
||||
TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
|
||||
evaluator.Evaluate<Literal>(
|
||||
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
|
||||
CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
|
||||
CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
|
||||
}
|
||||
|
||||
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
|
||||
@ -208,10 +208,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
|
||||
entry_computation->set_root_instruction(zeros);
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
|
||||
evaluator.Evaluate<std::unique_ptr<Literal>>(
|
||||
*module, {LiteralUtil::CreateR0<int32>(0)}));
|
||||
CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Literal result_literal,
|
||||
evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(0)}));
|
||||
CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
|
||||
}
|
||||
|
||||
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
|
||||
@ -229,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
|
||||
entry_computation->set_root_instruction(zeros);
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
|
||||
evaluator.Evaluate<std::unique_ptr<Literal>>(
|
||||
TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
|
||||
evaluator.Evaluate<Literal>(
|
||||
*module, {LiteralUtil::CreateR0<float>(0.0f)}));
|
||||
CHECK_EQ(*result_literal,
|
||||
*LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
|
||||
CHECK_EQ(result_literal,
|
||||
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -73,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
|
||||
|
||||
auto result = ExecuteAndTransfer(module->Clone(), {});
|
||||
auto expected = LiteralUtil::CreateR0<float>(84.0);
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
|
||||
}
|
||||
|
||||
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
|
||||
@ -105,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
|
||||
|
||||
auto result = ExecuteAndTransfer(module->Clone(), {});
|
||||
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
|
||||
}
|
||||
|
||||
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
|
||||
@ -135,7 +135,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
|
||||
|
||||
auto result = ExecuteAndTransfer(module->Clone(), {});
|
||||
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
|
||||
}
|
||||
|
||||
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
|
||||
|
@ -54,9 +54,8 @@ namespace xla {
|
||||
namespace {
|
||||
|
||||
template <typename OperandT>
|
||||
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
|
||||
LiteralSlice lhs_literal,
|
||||
LiteralSlice rhs_literal) {
|
||||
StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
|
||||
LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
|
||||
std::function<bool(OperandT, OperandT)> compare_op;
|
||||
switch (opcode) {
|
||||
case HloOpcode::kEq:
|
||||
@ -94,9 +93,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
<< HloOpcodeString(opcode);
}

auto result = absl::make_unique<Literal>(shape);
Literal result(shape);
TF_RETURN_IF_ERROR(
result->Populate<bool>([&](absl::Span<const int64> multi_index) {
result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<OperandT>(multi_index),
rhs_literal.Get<OperandT>(multi_index));
}));
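
The hunk above shows the core idiom that replaces absl::make_unique<Literal>(shape) throughout this file: construct a Literal directly from a Shape and fill it in place. A standalone sketch of the same idiom, assuming the xla namespace and a Status-returning function (shape and predicate invented for illustration):

Shape shape = ShapeUtil::MakeShape(PRED, {2, 3});
Literal result(shape);
// Populate visits every element index and stores the lambda's return value.
TF_RETURN_IF_ERROR(result.Populate<bool>([&](absl::Span<const int64> index) {
  return index[0] == index[1];  // arbitrary predicate for the example
}));
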
|
||||
@ -105,9 +104,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
|
||||
}
|
||||
|
||||
template <>
|
||||
StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
|
||||
const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal,
|
||||
LiteralSlice rhs_literal) {
|
||||
StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
|
||||
LiteralSlice lhs_literal,
|
||||
LiteralSlice rhs_literal) {
|
||||
std::function<bool(complex64, complex64)> compare_op;
|
||||
switch (opcode) {
|
||||
case HloOpcode::kEq:
|
||||
@ -125,9 +124,9 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
|
||||
<< HloOpcodeString(opcode);
|
||||
}
|
||||
|
||||
auto result = absl::make_unique<Literal>(shape);
|
||||
Literal result(shape);
|
||||
TF_RETURN_IF_ERROR(
|
||||
result->Populate<bool>([&](absl::Span<const int64> multi_index) {
|
||||
result.Populate<bool>([&](absl::Span<const int64> multi_index) {
|
||||
return compare_op(lhs_literal.Get<complex64>(multi_index),
|
||||
rhs_literal.Get<complex64>(multi_index));
|
||||
}));
|
||||
@ -193,7 +192,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations)
|
||||
}
|
||||
|
||||
template <typename LiteralPtr>
|
||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
||||
StatusOr<Literal> HloEvaluator::Evaluate(
|
||||
const HloModule& module, absl::Span<const LiteralPtr> arg_literals) {
|
||||
XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
|
||||
|
||||
@ -206,11 +205,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
||||
TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
|
||||
|
||||
return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
|
||||
.CloneToUnique();
|
||||
.Clone();
|
||||
}
|
||||
|
||||
template <>
|
||||
StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
|
||||
const HloModule& module, absl::Span<const Literal> arg_literals) {
|
||||
std::vector<const Literal*> arg_literal_ptrs;
|
||||
for (const auto& literal_ptr : arg_literals) {
|
||||
arg_literal_ptrs.push_back(&literal_ptr);
|
||||
}
|
||||
return Evaluate<const Literal*>(module, arg_literal_ptrs);
|
||||
}
|
||||
|
||||
template <typename LiteralPtr>
|
||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
||||
StatusOr<Literal> HloEvaluator::Evaluate(
|
||||
const HloComputation& computation,
|
||||
absl::Span<const LiteralPtr> arg_literals) {
|
||||
CHECK(computation.parent() != nullptr);
|
||||
@ -224,11 +233,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(computation.Accept(this));
|
||||
return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
|
||||
return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
|
||||
}
|
||||
|
||||
template <>
|
||||
StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
|
||||
const HloComputation& computation, absl::Span<const Literal> arg_literals) {
|
||||
std::vector<const Literal*> arg_literal_ptrs;
|
||||
for (const auto& literal_ptr : arg_literals) {
|
||||
arg_literal_ptrs.push_back(&literal_ptr);
|
||||
}
|
||||
return Evaluate<const Literal*>(computation, arg_literal_ptrs);
|
||||
}
|
||||
|
||||
template <typename LiteralPtr>
|
||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
||||
StatusOr<Literal> HloEvaluator::Evaluate(
|
||||
HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals) {
|
||||
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
|
||||
|
||||
@ -247,18 +266,27 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
||||
<< input_literal->ToString();
|
||||
TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
|
||||
|
||||
evaluated_[operand] = input_literal->CloneToUnique();
|
||||
evaluated_[operand] = input_literal->Clone();
|
||||
}
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(Preprocess(instruction));
|
||||
TF_RETURN_IF_ERROR(instruction->Visit(this));
|
||||
TF_RETURN_IF_ERROR(Postprocess(instruction));
|
||||
return GetEvaluatedLiteralFor(instruction).CloneToUnique();
|
||||
return GetEvaluatedLiteralFor(instruction).Clone();
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
||||
HloInstruction* instruction) {
|
||||
template <>
|
||||
StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
|
||||
HloInstruction* instruction, absl::Span<const Literal> arg_literals) {
|
||||
std::vector<const Literal*> arg_literal_ptrs;
|
||||
for (const auto& literal : arg_literals) {
|
||||
arg_literal_ptrs.push_back(&literal);
|
||||
}
|
||||
return Evaluate<const Literal*>(instruction, arg_literal_ptrs);
|
||||
}
|
||||
|
||||
StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
|
||||
if (instruction->opcode() == HloOpcode::kParameter) {
|
||||
return tensorflow::errors::FailedPrecondition(
|
||||
"Cannot evaluate a parameter.");
|
||||
@ -274,21 +302,22 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
||||
TF_RETURN_IF_ERROR(Preprocess(instruction));
|
||||
TF_RETURN_IF_ERROR(instruction->Visit(this));
|
||||
TF_RETURN_IF_ERROR(Postprocess(instruction));
|
||||
return GetEvaluatedLiteralFor(instruction).CloneToUnique();
|
||||
return GetEvaluatedLiteralFor(instruction).Clone();
|
||||
}
|
||||
|
||||
std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
|
||||
HloInstruction* instruction) {
|
||||
bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
|
||||
CHECK(result != nullptr);
|
||||
auto result_or = Evaluate(instruction);
|
||||
if (!result_or.ok()) {
|
||||
VLOG(1) << "TryEvaluate failed:" << result_or.status();
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
|
||||
return result_or.ConsumeValueOrDie();
|
||||
*result = result_or.ConsumeValueOrDie();
|
||||
return true;
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
|
||||
StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
|
||||
const HloInstruction* instruction,
|
||||
const std::unordered_map<const HloInstruction*, const Literal*>&
|
||||
substitutions) {
|
||||
@ -299,7 +328,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
|
||||
owned_operands.push_back(operand->Clone());
|
||||
} else {
|
||||
owned_operands.push_back(
|
||||
HloInstruction::CreateConstant(it->second->CloneToUnique()));
|
||||
HloInstruction::CreateConstant(it->second->Clone()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -316,12 +345,12 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
|
||||
return result;
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
|
||||
StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
|
||||
HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
|
||||
std::unique_ptr<HloInstruction> lhs_instr =
|
||||
HloInstruction::CreateConstant(lhs.CloneToUnique());
|
||||
HloInstruction::CreateConstant(lhs.Clone());
|
||||
std::unique_ptr<HloInstruction> rhs_instr =
|
||||
HloInstruction::CreateConstant(rhs.CloneToUnique());
|
||||
HloInstruction::CreateConstant(rhs.Clone());
|
||||
|
||||
std::unique_ptr<HloInstruction> cloned_instruction =
|
||||
HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
|
||||
@ -331,10 +360,10 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
|
||||
return result;
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
|
||||
StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
|
||||
HloOpcode opcode, const Literal& operand) {
|
||||
std::unique_ptr<HloInstruction> operand_instr =
|
||||
HloInstruction::CreateConstant(operand.CloneToUnique());
|
||||
HloInstruction::CreateConstant(operand.Clone());
|
||||
|
||||
std::unique_ptr<HloInstruction> cloned_instruction =
|
||||
HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
|
||||
@ -343,14 +372,14 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
|
||||
return result;
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
|
||||
StatusOr<Literal> HloEvaluator::EvaluateDotOp(
|
||||
const DotDimensionNumbers& dim_numbers,
|
||||
const PrecisionConfig& precision_config, const Literal& lhs,
|
||||
const Literal& rhs) {
|
||||
std::unique_ptr<HloInstruction> lhs_instr =
|
||||
HloInstruction::CreateConstant(lhs.CloneToUnique());
|
||||
HloInstruction::CreateConstant(lhs.Clone());
|
||||
std::unique_ptr<HloInstruction> rhs_instr =
|
||||
HloInstruction::CreateConstant(rhs.CloneToUnique());
|
||||
HloInstruction::CreateConstant(rhs.Clone());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape dot_shape,
|
||||
@ -371,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
|
||||
<< ", but input literal shape is: "
|
||||
<< ShapeUtil::HumanString(input_literal->shape());
|
||||
|
||||
evaluated_[parameter] = input_literal->CloneToUnique();
|
||||
evaluated_[parameter] = input_literal->Clone();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -421,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
|
||||
|
||||
for (auto operand : operands) {
|
||||
const Shape& operand_shape = operand->shape();
|
||||
TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
|
||||
TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
|
||||
GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
|
||||
AsInt64Slice(operand_shape.dimensions())));
|
||||
dest_indices[concat_dim] +=
|
||||
@ -824,7 +853,7 @@ class OutputOffsetIndexToInputIndex {
|
||||
// there is one) to `reshaped_start_indices`.
|
||||
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
|
||||
int64 index_vector_dim, const Literal& start_indices,
|
||||
std::unique_ptr<Literal>* reshaped_start_indices) {
|
||||
Literal* reshaped_start_indices) {
|
||||
if (start_indices.shape().dimensions_size() != index_vector_dim) {
|
||||
return std::cref(start_indices);
|
||||
}
|
||||
@ -834,16 +863,16 @@ static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
|
||||
new_shape.push_back(1);
|
||||
TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
|
||||
start_indices.Reshape(new_shape));
|
||||
return std::cref(**reshaped_start_indices);
|
||||
return std::cref(*reshaped_start_indices);
|
||||
}
|
||||
|
||||
Status HloEvaluator::HandleGather(HloInstruction* gather) {
|
||||
std::unique_ptr<Literal> result = Literal::CreateFromShape(gather->shape());
|
||||
Literal result = Literal::CreateFromShape(gather->shape());
|
||||
const Shape& shape = gather->shape();
|
||||
const GatherDimensionNumbers& dim_numbers =
|
||||
gather->gather_dimension_numbers();
|
||||
const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
|
||||
std::unique_ptr<Literal> reshaped_start_indices;
|
||||
Literal reshaped_start_indices;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
const Literal& start_indices,
|
||||
ReshapedGatherIndices(dim_numbers.index_vector_dim(),
|
||||
@ -908,7 +937,7 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
|
||||
DCHECK_LT(input_index[i], operand_shape.dimensions(i));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
result->CopyElementFrom(operand, input_index, output_index));
|
||||
result.CopyElementFrom(operand, input_index, output_index));
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -977,18 +1006,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
|
||||
|
||||
const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
|
||||
|
||||
evaluated_[get_tuple_element] = absl::make_unique<Literal>(
|
||||
ShapeUtil::GetTupleElementShape(operand->shape(), index));
|
||||
return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
|
||||
/*dest_shape_index=*/{},
|
||||
/*src_shape_index=*/{index});
|
||||
evaluated_[get_tuple_element] =
|
||||
Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
|
||||
return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
|
||||
/*dest_shape_index=*/{},
|
||||
/*src_shape_index=*/{index});
|
||||
}
|
||||
|
||||
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));

auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
evaluated_[copy] = std::move(result);
evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
return Status::OK();
}
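
HandleCopy and the select handlers below illustrate the two ownership operations that replace CloneToUnique in this commit: Clone() produces an independent Literal value, and std::move transfers an existing one. A tiny illustrative sketch (names invented, xla namespace assumed):

Literal a = LiteralUtil::CreateR0<float>(42.0f);
Literal b = a.Clone();     // deep copy; `a` remains valid and unchanged
Literal c = std::move(a);  // transfers the buffer; `a` is left moved-from
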
|
||||
|
||||
@ -1004,7 +1031,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) {
|
||||
}
|
||||
|
||||
HloEvaluator embedded_evaluator;
|
||||
std::unique_ptr<Literal> result =
|
||||
Literal result =
|
||||
embedded_evaluator.Evaluate<const Literal*>(*computation, arg_literals)
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
@ -1036,7 +1063,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
|
||||
}
|
||||
|
||||
HloEvaluator embedded_evaluator;
|
||||
std::unique_ptr<Literal> result =
|
||||
Literal result =
|
||||
embedded_evaluator
|
||||
.Evaluate<const Literal*>(*readded_computation, arg_literals)
|
||||
.ConsumeValueOrDie();
|
||||
@ -1056,7 +1083,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
|
||||
auto* false_computation = conditional->false_computation();
|
||||
|
||||
HloEvaluator embedded_evaluator;
|
||||
std::unique_ptr<Literal> result;
|
||||
Literal result;
|
||||
if (pred.Get<bool>({})) {
|
||||
result = embedded_evaluator
|
||||
.Evaluate<const Literal*>(*true_computation,
|
||||
@ -1081,9 +1108,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
|
||||
// If predicate is of scalar type, no element-wise selection would be needed.
|
||||
if (ShapeUtil::IsScalar(pred.shape())) {
|
||||
if (pred.Get<bool>({})) {
|
||||
evaluated_[select] = on_true.CloneToUnique();
|
||||
evaluated_[select] = on_true.Clone();
|
||||
} else {
|
||||
evaluated_[select] = on_false.CloneToUnique();
|
||||
evaluated_[select] = on_false.Clone();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1097,9 +1124,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
|
||||
const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
|
||||
|
||||
if (pred.Get<bool>({})) {
|
||||
evaluated_[tuple_select] = on_true.CloneToUnique();
|
||||
evaluated_[tuple_select] = on_true.Clone();
|
||||
} else {
|
||||
evaluated_[tuple_select] = on_false.CloneToUnique();
|
||||
evaluated_[tuple_select] = on_false.Clone();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1108,7 +1135,7 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
|
||||
HloComputation* cond_comp = while_hlo->while_condition();
|
||||
HloComputation* body_comp = while_hlo->while_body();
|
||||
// Initialize the loop carried value with the input to the While instruction.
|
||||
auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique();
|
||||
auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
|
||||
bool keep_going = true;
|
||||
int64 iteration_count = 0;
|
||||
HloEvaluator cond_evaluator(max_loop_iterations_);
|
||||
@ -1118,13 +1145,13 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
|
||||
return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
|
||||
while_hlo->name(), max_loop_iterations_);
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate<Literal*>(
|
||||
*cond_comp, {lcv.get()}));
|
||||
keep_going = cond_val->GetFirstElement<bool>();
|
||||
TF_ASSIGN_OR_RETURN(auto cond_val,
|
||||
cond_evaluator.Evaluate<Literal*>(*cond_comp, {&lcv}));
|
||||
keep_going = cond_val.GetFirstElement<bool>();
|
||||
if (keep_going) {
|
||||
TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate<Literal*>(
|
||||
*body_comp, {lcv.get()}));
|
||||
VLOG(3) << "Loop iteration result: " << body_val->ToString();
|
||||
*body_comp, {&lcv}));
|
||||
VLOG(3) << "Loop iteration result: " << body_val.ToString();
|
||||
lcv = std::move(body_val);
|
||||
cond_evaluator.ResetVisitStates();
|
||||
loop_body_evaluator.ResetVisitStates();
|
||||
@ -1139,9 +1166,9 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
|
||||
// hoops to make this work.
|
||||
namespace {
|
||||
template <typename KeyType, typename ValueType>
|
||||
StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
|
||||
HloInstruction* sort, const Literal& keys_literal,
|
||||
const Literal& values_literal) {
|
||||
StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
|
||||
const Literal& keys_literal,
|
||||
const Literal& values_literal) {
|
||||
auto rank = ShapeUtil::Rank(keys_literal.shape());
|
||||
TF_RET_CHECK(
|
||||
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
|
||||
@ -1179,57 +1206,55 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
|
||||
result_keys.push_back(key_value.first);
|
||||
result_values.push_back(key_value.second);
|
||||
}
|
||||
auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape());
|
||||
result_keys_literal->PopulateR1(absl::Span<const KeyType>(result_keys));
|
||||
auto result_values_literal =
|
||||
absl::make_unique<Literal>(values_literal.shape());
|
||||
result_values_literal->PopulateR1(
|
||||
Literal result_keys_literal(keys_literal.shape());
|
||||
result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
|
||||
Literal result_values_literal(values_literal.shape());
|
||||
result_values_literal.PopulateR1(
|
||||
absl::Span<const ValueType>(result_values));
|
||||
return std::make_pair(std::move(result_keys_literal),
|
||||
std::move(result_values_literal));
|
||||
};
|
||||
|
||||
std::unique_ptr<Literal> result_tuple;
|
||||
Literal result_tuple;
|
||||
if (rank == 1) {
|
||||
auto result_pair = sort_r1(keys_literal, values_literal);
|
||||
result_tuple = LiteralUtil::MakeTuple(
|
||||
{result_pair.first.get(), result_pair.second.get()});
|
||||
result_tuple =
|
||||
LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
|
||||
} else {
|
||||
// For R2 sort, the desired semantics are to sort each matrix row
|
||||
// independently.
|
||||
auto keys_result_literal = absl::make_unique<Literal>(keys_literal.shape());
|
||||
auto values_result_literal =
|
||||
absl::make_unique<Literal>(values_literal.shape());
|
||||
Literal keys_result_literal(keys_literal.shape());
|
||||
Literal values_result_literal(values_literal.shape());
|
||||
int64 r1_length = keys_literal.shape().dimensions(1);
|
||||
for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
|
||||
TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
|
||||
keys_literal.Slice({row, 0}, {row + 1, r1_length})
|
||||
->Reshape({r1_length}));
|
||||
.Reshape({r1_length}));
|
||||
TF_ASSIGN_OR_RETURN(auto values_r1_slice,
|
||||
values_literal.Slice({row, 0}, {row + 1, r1_length})
|
||||
->Reshape({r1_length}));
|
||||
auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice);
|
||||
.Reshape({r1_length}));
|
||||
auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
|
||||
TF_ASSIGN_OR_RETURN(auto sorted_keys,
|
||||
r1_result_pair.first->Reshape({1, r1_length}));
|
||||
r1_result_pair.first.Reshape({1, r1_length}));
|
||||
TF_ASSIGN_OR_RETURN(auto sorted_values,
|
||||
r1_result_pair.second->Reshape({1, r1_length}));
|
||||
TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom(
|
||||
*sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
|
||||
TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom(
|
||||
*sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
|
||||
r1_result_pair.second.Reshape({1, r1_length}));
|
||||
TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
|
||||
sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
|
||||
TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
|
||||
sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
|
||||
}
|
||||
result_tuple = LiteralUtil::MakeTuple(
|
||||
{keys_result_literal.get(), values_result_literal.get()});
|
||||
result_tuple =
|
||||
LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
|
||||
}
|
||||
|
||||
VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
|
||||
VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
|
||||
return std::move(result_tuple);
|
||||
}
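
The row-wise sort above relies on two Literal operations that now work on values: Slice returns a new Literal, and Reshape returns a StatusOr<Literal>. A small sketch of extracting one row of a rank-2 literal as a rank-1 literal, assuming keys_literal, row, and r1_length as in the loop above:

// Take the single-row slice [row, row+1) x [0, r1_length) and drop the leading
// dimension so the result is rank 1.
TF_ASSIGN_OR_RETURN(Literal row_keys,
                    keys_literal.Slice({row, 0}, {row + 1, r1_length})
                        .Reshape({r1_length}));
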
|
||||
|
||||
template <typename KeyType>
|
||||
StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
|
||||
HloInstruction* sort, const Literal& keys_literal,
|
||||
const Literal& values_literal) {
|
||||
StatusOr<Literal> EvaluateSortCurried(HloInstruction* sort,
|
||||
const Literal& keys_literal,
|
||||
const Literal& values_literal) {
|
||||
switch (sort->operand(1)->shape().element_type()) {
|
||||
case F32:
|
||||
return EvaluateSortInternal<KeyType, float>(sort, keys_literal,
|
||||
@ -1248,9 +1273,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> EvaluateSort(HloInstruction* sort,
|
||||
const Literal& keys_literal,
|
||||
const Literal& values_literal) {
|
||||
StatusOr<Literal> EvaluateSort(HloInstruction* sort,
|
||||
const Literal& keys_literal,
|
||||
const Literal& values_literal) {
|
||||
switch (sort->operand(0)->shape().element_type()) {
|
||||
case F32:
|
||||
return EvaluateSortCurried<float>(sort, keys_literal, values_literal);
|
||||
@ -1319,28 +1344,14 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) {
|
||||
|
||||
// Explicit instantiation of templatized Evaluate* methods.
|
||||
//
|
||||
template StatusOr<std::unique_ptr<Literal>>
|
||||
HloEvaluator::Evaluate<const Literal*>(
|
||||
template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
|
||||
const HloModule& module, absl::Span<const Literal* const> arg_literals);
|
||||
template StatusOr<std::unique_ptr<Literal>>
|
||||
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
|
||||
const HloModule& module,
|
||||
absl::Span<const std::unique_ptr<Literal>> arg_literals);
|
||||
|
||||
template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
|
||||
const Literal*>(const HloComputation& computation,
|
||||
absl::Span<const Literal* const> arg_literals);
|
||||
template StatusOr<std::unique_ptr<Literal>>
|
||||
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
|
||||
template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
|
||||
const HloComputation& computation,
|
||||
absl::Span<const std::unique_ptr<Literal>> arg_literals);
|
||||
absl::Span<const Literal* const> arg_literals);
|
||||
|
||||
template StatusOr<std::unique_ptr<Literal>>
|
||||
HloEvaluator::Evaluate<const Literal*>(
|
||||
template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
|
||||
HloInstruction* instruction, absl::Span<const Literal* const> arg_literals);
|
||||
template StatusOr<std::unique_ptr<Literal>>
|
||||
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
|
||||
HloInstruction* instruction,
|
||||
absl::Span<const std::unique_ptr<Literal>> arg_literals);
|
||||
|
||||
} // namespace xla
|
||||
|
@ -47,11 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
|
||||
// Precondition: The indices of arg_literals correspond to the parameter
|
||||
// numbers of the HLO parameters in the computation. See comment below for an
|
||||
// example.
|
||||
// `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
|
||||
// `LiteralPtr` accepts either Literal or const Literal*
|
||||
// type.
|
||||
template <typename LiteralPtr>
|
||||
StatusOr<std::unique_ptr<Literal>> Evaluate(
|
||||
const HloModule& module, absl::Span<const LiteralPtr> arg_literals);
|
||||
StatusOr<Literal> Evaluate(const HloModule& module,
|
||||
absl::Span<const LiteralPtr> arg_literals);
|
||||
|
||||
// Evaluates an HLO computation and an array of pointers to literals.
|
||||
// Returns the evaluated result as a literal if successful.
|
||||
@ -69,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
|
||||
// where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
|
||||
// 1 in this computation. The input literals array will then have its first
|
||||
// literal map to Parameter0 and the second map to Parameter1.
|
||||
// `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
|
||||
// `LiteralPtr` accepts either Literal or const Literal*
|
||||
// type.
|
||||
template <typename LiteralPtr>
|
||||
StatusOr<std::unique_ptr<Literal>> Evaluate(
|
||||
const HloComputation& computation,
|
||||
absl::Span<const LiteralPtr> arg_literals);
|
||||
StatusOr<Literal> Evaluate(const HloComputation& computation,
|
||||
absl::Span<const LiteralPtr> arg_literals);
|
||||
|
||||
// Evaluates a single HLO instruction and an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
|
||||
@ -82,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
|
||||
// 1. argument literals correspond to the input instruction's parameters in
|
||||
// their post-ordering.
|
||||
// 2. the instruction's operands must be of either Parameter or Constant type.
|
||||
// `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
|
||||
// `LiteralPtr` accepts either Literal or const Literal*
|
||||
// type.
|
||||
template <typename LiteralPtr>
|
||||
StatusOr<std::unique_ptr<Literal>> Evaluate(
|
||||
HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals);
|
||||
StatusOr<Literal> Evaluate(HloInstruction* instruction,
|
||||
absl::Span<const LiteralPtr> arg_literals);
|
||||
|
||||
// Evaluates a single HLO instruction with constant operands.
|
||||
// Returns the evaluated result as literal if successful.
|
||||
// Precondition:
|
||||
// 1. all operands of the input instruction are constants.
|
||||
// 2. the instruction is not a Parameter operation.
|
||||
StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
|
||||
StatusOr<Literal> Evaluate(HloInstruction* instruction);
|
||||
|
||||
// Same as Evaluate, except returning nullptr on error.
std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
// Same as Evaluate, except it returns false on error and accepts an output
// pointer.
bool TryEvaluate(HloInstruction* instruction, Literal* result);
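
A short usage sketch for the new TryEvaluate signature, mirroring how hlo_constant_folding.cc uses it earlier in this diff (the evaluator and the instruction with constant operands are assumed to exist in the caller):

Literal folded;
if (evaluator.TryEvaluate(instruction, &folded)) {
  // `folded` now owns the evaluated value.
  VLOG(2) << "Folded to: " << folded.ToString();
} else {
  // Evaluation failed (e.g. an unimplemented op); fall back.
  VLOG(2) << "Could not evaluate " << instruction->ToString();
}
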
|
||||
|
||||
// Evaluates a single HLO instruction, substituting the given literals for
|
||||
// some of the instruction's operands.
|
||||
//
|
||||
// For example, given instruction = op(A, B, C) and the map
|
||||
// {A = x, C = y}, this evaluates op(x, B, y).
|
||||
StatusOr<std::unique_ptr<Literal>> EvaluateWithSubstitutions(
|
||||
StatusOr<Literal> EvaluateWithSubstitutions(
|
||||
const HloInstruction* instruction,
|
||||
const std::unordered_map<const HloInstruction*, const Literal*>&
|
||||
substitutions);
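
A sketch of how EvaluateWithSubstitutions might be called, based only on the signature and comment above (the instruction, its operand `a`, and the evaluator are hypothetical):

// `add` is an HloInstruction* computing A + B and `a` is the operand to pin.
Literal two = LiteralUtil::CreateR0<float>(2.0f);
std::unordered_map<const HloInstruction*, const Literal*> substitutions;
substitutions[a] = &two;
// Evaluates add as if A were the constant 2.0f; B keeps its own operand.
TF_ASSIGN_OR_RETURN(Literal sum,
                    evaluator.EvaluateWithSubstitutions(add, substitutions));
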
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseBinaryOp(
|
||||
HloOpcode opcode, const Literal& lhs, const Literal& rhs);
|
||||
StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
|
||||
const Literal& lhs,
|
||||
const Literal& rhs);
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
|
||||
HloOpcode opcode, const Literal& operand);
|
||||
StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
|
||||
const Literal& operand);
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
|
||||
const DotDimensionNumbers& dim_numbers,
|
||||
const PrecisionConfig& precision_config, const Literal& lhs,
|
||||
const Literal& rhs);
|
||||
StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
|
||||
const PrecisionConfig& precision_config,
|
||||
const Literal& lhs, const Literal& rhs);
|
||||
|
||||
protected:
|
||||
// Make HloEvaluatorTypedVisitor a friend because it is logically part of this
|
||||
@ -197,7 +197,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
|
||||
auto it = evaluated_.find(hlo);
|
||||
CHECK(it != evaluated_.end())
|
||||
<< "could not find evaluated value for: " << hlo->ToString();
|
||||
return *(it->second);
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Tracks the HLO instruction and its evaluated literal result.
|
||||
@ -205,12 +205,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// that are no longer a parent for any other subsequent instruction in
// post-ordering.
// Must be cleared for each evaluation.
tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
evaluated_;
// Storing Literal in place requires the container to have pointer stability,
// so we cannot use FlatMap any more.
std::unordered_map<const HloInstruction*, Literal> evaluated_;
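
The comment above is the one data-structure consequence of storing Literal by value: the map must guarantee that references to stored values stay valid while new entries are inserted, which std::unordered_map does and an open-addressing flat map generally does not. A self-contained illustration of that guarantee in plain standard C++, independent of XLA:

#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<int, std::string> m;
  std::string& first = m[0] = "kept alive";
  // Force growth and rehashing. Rehashing may invalidate iterators, but
  // references and pointers to existing elements remain valid.
  for (int i = 1; i < 10000; ++i) m[i] = "filler";
  assert(first == "kept alive");
  return 0;
}
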
|
||||
|
||||
private:
|
||||
template <typename ReturnT, typename NativeT>
|
||||
static StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
|
||||
static StatusOr<Literal> ElementWiseUnaryOpImpl(
|
||||
HloInstruction* instruction,
|
||||
const std::function<ReturnT(NativeT)>& unary_op,
|
||||
const Literal& operand_literal) {
|
||||
@ -227,9 +228,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
|
||||
ShapeUtil::HumanString(operand->shape()));
|
||||
}
|
||||
|
||||
auto result = absl::make_unique<Literal>(shape);
|
||||
Literal result(shape);
|
||||
TF_RETURN_IF_ERROR(
|
||||
result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
|
||||
result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
|
||||
return unary_op(operand_literal.Get<NativeT>(multi_index));
|
||||
}));
|
||||
return std::move(result);
|
||||
|
File diff suppressed because it is too large
@ -246,15 +246,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
Status HandleConvert(HloInstruction* convert) override {
|
||||
const HloInstruction* operand = convert->operand(0);
|
||||
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
|
||||
TF_ASSIGN_OR_RETURN(Literal result,
|
||||
parent_->GetEvaluatedLiteralFor(operand).Convert(
|
||||
convert->shape().element_type()));
|
||||
|
||||
if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
|
||||
if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
|
||||
parent_->evaluated_[convert] = std::move(result);
|
||||
} else {
|
||||
parent_->evaluated_[convert] =
|
||||
result->Relayout(convert->shape().layout());
|
||||
parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
|
||||
}
|
||||
return Status::OK();
|
||||
}
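
HandleConvert above composes two Literal operations that now return values: Convert yields a StatusOr<Literal> in the requested element type, and Relayout returns a copy with the requested layout. A minimal sketch, assuming the xla namespace, a Status-returning context, and a hypothetical target_shape whose layout we want:

Literal ints = LiteralUtil::CreateR1<int32>({1, 2, 3});
// Element-type conversion: S32 -> F32.
TF_ASSIGN_OR_RETURN(Literal floats, ints.Convert(F32));
// Re-lay out the converted literal to match the target layout.
Literal relaid = floats.Relayout(target_shape.layout());
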
|
||||
@ -262,15 +261,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
Status HandleBitcastConvert(HloInstruction* convert) override {
|
||||
const HloInstruction* operand = convert->operand(0);
|
||||
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
|
||||
TF_ASSIGN_OR_RETURN(Literal result,
|
||||
parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
|
||||
convert->shape().element_type()));
|
||||
|
||||
if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
|
||||
if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
|
||||
parent_->evaluated_[convert] = std::move(result);
|
||||
} else {
|
||||
parent_->evaluated_[convert] =
|
||||
result->Relayout(convert->shape().layout());
|
||||
parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -978,10 +976,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
<< ShapeUtil::HumanString(inferred_return_shape);
|
||||
|
||||
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
|
||||
auto result = absl::make_unique<Literal>(result_shape);
|
||||
Literal result(result_shape);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
result->Populate<ReturnT>([&](absl::Span<const int64> out_index) {
|
||||
result.Populate<ReturnT>([&](absl::Span<const int64> out_index) {
|
||||
std::vector<int64> from_index(out_index.begin(), out_index.end());
|
||||
for (const int64 dim : reverse_dimensions) {
|
||||
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
|
||||
@ -1157,8 +1155,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return static_cast<ReturnT>(result_val);
|
||||
};
|
||||
|
||||
auto result = absl::make_unique<Literal>(result_shape);
|
||||
TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
|
||||
Literal result(result_shape);
|
||||
TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func));
|
||||
|
||||
parent_->evaluated_[conv] = std::move(result);
|
||||
return Status::OK();
|
||||
@ -1231,9 +1229,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
}
|
||||
|
||||
auto result = absl::make_unique<Literal>(dot->shape());
|
||||
Literal result(dot->shape());
|
||||
TF_RETURN_IF_ERROR(
|
||||
result->Populate<ReturnT>([&](absl::Span<const int64> result_index) {
|
||||
result.Populate<ReturnT>([&](absl::Span<const int64> result_index) {
|
||||
ElementwiseT result_val = static_cast<ElementwiseT>(0);
|
||||
|
||||
for (int64 i = 0; i < result_index.size(); i++) {
|
||||
@ -1280,8 +1278,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
// Create new HLO of padded shape with padding value.
|
||||
ReturnT scalar =
|
||||
parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
|
||||
auto result = absl::make_unique<Literal>(pad->shape());
|
||||
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
|
||||
Literal result(pad->shape());
|
||||
TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
|
||||
[&scalar](absl::Span<const int64> multi_index) { return scalar; }));
|
||||
|
||||
const Literal& evaluated_operand =
|
||||
@ -1289,7 +1287,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
|
||||
0);
|
||||
std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
|
||||
std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0);
|
||||
|
||||
// Loop through each element of the operand, assign them to the
|
||||
// corresponding index of the resulting padded literal.
|
||||
@ -1311,8 +1309,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
result->Set<ReturnT>(target_index,
|
||||
evaluated_operand.Get<ReturnT>(input_index));
|
||||
result.Set<ReturnT>(target_index,
|
||||
evaluated_operand.Get<ReturnT>(input_index));
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -1439,16 +1437,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
|
||||
StatusOr<Literal> MapImpl(HloInstruction* map) {
|
||||
auto operands = map->operands();
|
||||
HloComputation* computation = map->to_apply();
|
||||
|
||||
auto result = absl::make_unique<Literal>(map->shape());
|
||||
Literal result(map->shape());
|
||||
|
||||
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
|
||||
TF_RETURN_IF_ERROR(
|
||||
result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
|
||||
std::vector<std::unique_ptr<Literal>> arg_literals;
|
||||
result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
|
||||
std::vector<Literal> arg_literals;
|
||||
arg_literals.reserve(operands.size());
|
||||
|
||||
// Construct scalar literal parameters to be passed to the map
|
||||
@ -1463,16 +1461,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
arg_literals.push_back(std::move(curr_val_literal));
|
||||
}
|
||||
|
||||
std::unique_ptr<Literal> computed_result =
|
||||
embedded_evaluator
|
||||
.Evaluate<std::unique_ptr<Literal>>(*computation,
|
||||
arg_literals)
|
||||
Literal computed_result =
|
||||
embedded_evaluator.Evaluate<Literal>(*computation, arg_literals)
|
||||
.ConsumeValueOrDie();
|
||||
// Clear visit states so that we can use the evaluator again on
// the same computation.
|
||||
embedded_evaluator.ResetVisitStates();
|
||||
|
||||
return computed_result->Get<ReturnT>({});
|
||||
return computed_result.Get<ReturnT>({});
|
||||
}));
|
||||
return std::move(result);
|
||||
}
|
||||
@ -1557,9 +1553,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
[](const ReturnT& a, const ReturnT& b) {
|
||||
return SafeLess<ReturnT>(a, b);
|
||||
});
|
||||
auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
|
||||
result_literal->PopulateR1(absl::Span<const ReturnT>(result_data));
|
||||
VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
|
||||
Literal result_literal(keys_literal.shape());
|
||||
result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
|
||||
VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
|
||||
return result_literal;
|
||||
};
|
||||
|
||||
@ -1568,16 +1564,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
} else {
|
||||
// For R2 sort, the desired semantics are to sort each matrix row
|
||||
// independently.
|
||||
auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
|
||||
Literal result_literal(keys_literal.shape());
|
||||
int64 r1_length = keys->shape().dimensions(1);
|
||||
for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
|
||||
TF_ASSIGN_OR_RETURN(auto r1_slice,
|
||||
keys_literal.Slice({row, 0}, {row + 1, r1_length})
|
||||
->Reshape({r1_length}));
|
||||
auto r1_result = sort_r1(*r1_slice);
|
||||
TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length}));
|
||||
TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
|
||||
*r1_result, {0, 0}, {row, 0}, {1, r1_length}));
|
||||
.Reshape({r1_length}));
|
||||
auto r1_result = sort_r1(r1_slice);
|
||||
TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
|
||||
TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
|
||||
r1_result, {0, 0}, {row, 0}, {1, r1_length}));
|
||||
}
|
||||
parent_->evaluated_[sort] = std::move(result_literal);
|
||||
}
|
||||
@ -1651,9 +1647,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
|
||||
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
|
||||
absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args);
|
||||
absl::InlinedVector<Literal, 1> results(num_args);
|
||||
for (int64 i = 0; i < num_args; ++i) {
|
||||
results[i] = absl::make_unique<Literal>(result_shape);
|
||||
results[i] = Literal(result_shape);
|
||||
}
|
||||
|
||||
Status eval_status;
|
||||
@ -1667,7 +1663,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
|
||||
for (int64 input = 0; input < num_args; ++input) {
|
||||
TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>(
|
||||
TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>(
|
||||
[&](absl::Span<const int64> multi_index) {
|
||||
if (!eval_status.ok()) {
|
||||
return init_scalars[input];
|
||||
@ -1703,8 +1699,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
|
||||
// Evaluate computation with specified literal operands.
|
||||
absl::InlinedVector<std::unique_ptr<Literal>, 1>
|
||||
embedded_operands;
|
||||
absl::InlinedVector<Literal, 1> embedded_operands;
|
||||
for (ReturnT value : result_values) {
|
||||
embedded_operands.push_back(
|
||||
LiteralUtil::CreateR0<ReturnT>(value));
|
||||
@ -1717,11 +1712,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
embedded_operands.size());
|
||||
std::transform(embedded_operands.begin(), embedded_operands.end(),
|
||||
embedded_operands_ptrs.begin(),
|
||||
[](const std::unique_ptr<Literal>& ptr) {
|
||||
return ptr.get();
|
||||
});
|
||||
[](Literal& literal) { return &literal; });
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
|
||||
TF_ASSIGN_OR_RETURN(Literal computed_result,
|
||||
embedded_evaluator.Evaluate<const Literal*>(
|
||||
*function, embedded_operands_ptrs));
|
||||
// Clear visit states so that we can use the evaluator again on
|
||||
@ -1729,10 +1722,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
embedded_evaluator.ResetVisitStates();
|
||||
// Assign computed result to result_val.
|
||||
if (!has_tuple_output) {
|
||||
result_values[0] = computed_result->Get<ReturnT>({});
|
||||
result_values[0] = computed_result.Get<ReturnT>({});
|
||||
} else {
|
||||
for (int64 i = 0; i < num_args; ++i) {
|
||||
result_values[i] = computed_result->Get<ReturnT>(
|
||||
result_values[i] = computed_result.Get<ReturnT>(
|
||||
/*multi_index=*/{}, /*shape_index=*/{i});
|
||||
}
|
||||
}
|
||||
@ -1748,9 +1741,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
if (!has_tuple_output) {
|
||||
parent_->evaluated_[reduce] = std::move(results[0]);
|
||||
} else {
auto tuple_result = absl::make_unique<Literal>(reduce->shape());
Literal tuple_result(reduce->shape());
for (int64 i = 0; i < num_args; ++i) {
TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i}));
TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
}
parent_->evaluated_[reduce] = std::move(tuple_result);
}
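
The tuple assembly above is worth calling out: a tuple-shaped Literal is constructed empty and each element literal is then moved into it with MoveFrom, avoiding a deep copy. A small sketch assuming a tuple of two scalar F32 elements (shape and values invented for illustration):

Shape tuple_shape = ShapeUtil::MakeTupleShape(
    {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
Literal tuple(tuple_shape);
Literal e0 = LiteralUtil::CreateR0<float>(1.0f);
Literal e1 = LiteralUtil::CreateR0<float>(2.0f);
// MoveFrom takes ownership of each element's buffer and installs it at the
// given shape index inside the tuple.
TF_CHECK_OK(tuple.MoveFrom(std::move(e0), {0}));
TF_CHECK_OK(tuple.MoveFrom(std::move(e1), {1}));
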
|
||||
@ -1781,10 +1774,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
|
||||
auto init_scalar = init_literal.Get<ReturnT>({});
|
||||
|
||||
auto result = absl::make_unique<Literal>(select_and_scatter->shape());
|
||||
Literal result(select_and_scatter->shape());
|
||||
|
||||
// Initialize result array with the init value.
|
||||
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
|
||||
TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
|
||||
[&](absl::Span<const int64> output_index) { return init_scalar; }));
|
||||
|
||||
std::vector<int64> window_dimension_sizes;
|
||||
@ -1834,15 +1827,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
selected_val = curr_val;
|
||||
selected_index = operand_index;
|
||||
}
|
||||
curr_val_literal->Set({}, curr_val);
|
||||
selected_val_literal->Set({}, *selected_val);
|
||||
std::unique_ptr<Literal> computed_result =
|
||||
curr_val_literal.Set({}, curr_val);
|
||||
selected_val_literal.Set({}, *selected_val);
|
||||
Literal computed_result =
|
||||
embedded_evaluator
|
||||
.Evaluate<const Literal*>(
|
||||
*select,
|
||||
{selected_val_literal.get(), curr_val_literal.get()})
|
||||
*select, {&selected_val_literal, &curr_val_literal})
|
||||
.ConsumeValueOrDie();
|
||||
bool selected = !computed_result->Get<bool>({});
|
||||
bool selected = !computed_result.Get<bool>({});
|
||||
if (selected) {
|
||||
selected_val = curr_val;
|
||||
selected_index = operand_index;
|
||||
@ -1856,16 +1848,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
if (std::equal(operand_index.begin(), operand_index.end(),
|
||||
selected_index->begin())) {
|
||||
auto source = source_literal.Get<ReturnT>(source_index);
|
||||
auto scattered = result->Get<ReturnT>(operand_index);
|
||||
source_literal_scatter->Set({}, source);
|
||||
scattered_literal->Set({}, scattered);
|
||||
std::unique_ptr<Literal> computed_result =
|
||||
auto scattered = result.Get<ReturnT>(operand_index);
|
||||
source_literal_scatter.Set({}, source);
|
||||
scattered_literal.Set({}, scattered);
|
||||
Literal computed_result =
|
||||
embedded_evaluator
|
||||
.Evaluate<const Literal*>(*scatter,
|
||||
{source_literal_scatter.get(),
|
||||
scattered_literal.get()})
|
||||
.Evaluate<const Literal*>(
|
||||
*scatter,
|
||||
{&source_literal_scatter, &scattered_literal})
|
||||
.ConsumeValueOrDie();
|
||||
result->Set(operand_index, computed_result->Get<ReturnT>({}));
|
||||
result.Set(operand_index, computed_result.Get<ReturnT>({}));
|
||||
// Clear visit states so that we can use the evaluator again
// on the same computation.
|
||||
embedded_evaluator.ResetVisitStates();
|
||||
@ -1916,10 +1908,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
|
||||
|
||||
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
|
||||
auto result = absl::make_unique<Literal>(reduce_window->shape());
|
||||
Literal result(reduce_window->shape());
|
||||
// For each resulting dimension, calculate and assign computed value.
|
||||
TF_RETURN_IF_ERROR(
|
||||
result->Populate<ReturnT>([&](absl::Span<const int64> output_index) {
|
||||
result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
|
||||
ReturnT result_val = init_scalar;
|
||||
|
||||
std::fill(window_index.begin(), window_index.end(), 0);
|
||||
@ -1935,18 +1927,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
LiteralUtil::CreateR0<ReturnT>(curr_val);
|
||||
const auto result_val_literal =
|
||||
LiteralUtil::CreateR0<ReturnT>(result_val);
|
||||
std::unique_ptr<Literal> computed_result =
|
||||
Literal computed_result =
|
||||
embedded_evaluator
|
||||
.Evaluate<const Literal*>(
|
||||
*function,
|
||||
{result_val_literal.get(), curr_val_literal.get()})
|
||||
*function, {&result_val_literal, &curr_val_literal})
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
// Clear visit states so that we can use the evaluator again
// on the same computation.
|
||||
embedded_evaluator.ResetVisitStates();
|
||||
|
||||
result_val = computed_result->Get<ReturnT>({});
|
||||
result_val = computed_result.Get<ReturnT>({});
|
||||
});
|
||||
|
||||
return result_val;
|
||||
@ -1961,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// literal (if there is one) to `reshaped_indices`.
StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
int64 index_vector_dim, const Literal& indices,
std::unique_ptr<Literal>* reshaped_indices) {
Literal* reshaped_indices) {
if (indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(indices);
}
@ -1970,7 +1961,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
indices.shape().dimensions().end());
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
return std::cref(**reshaped_indices);
return std::cref(*reshaped_indices);
}

// Returns an ShapeUtil::IndexIterationSpace that iterates over the update
@ -2230,7 +2221,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
scatter->scatter_dimension_numbers();
const Literal& operand =
parent_->GetEvaluatedLiteralFor(scatter->operand(0));
std::unique_ptr<Literal> reshaped_scatter_indices;
Literal reshaped_scatter_indices;
TF_ASSIGN_OR_RETURN(const Literal& scatter_indices,
ReshapedScatterIndices(dim_numbers.index_vector_dim(),
parent_->GetEvaluatedLiteralFor(
@ -2260,7 +2251,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {

// Initialize the result with the operand. This makes it easier to handle
// the updates even when the indices are repeated.
std::unique_ptr<Literal> result = operand.CloneToUnique();
Literal result = operand.Clone();
HloEvaluator embedded_evaluator;
auto scatter_inner_loop_body =
[&](absl::Span<const int64> update_window_index,
@ -2299,19 +2290,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}

auto result_value_literal =
LiteralUtil::CreateR0<ReturnT>(result->Get<ReturnT>(input_index));
LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index));
auto update_value_literal =
LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index));
std::unique_ptr<Literal> updated_result =
Literal updated_result =
embedded_evaluator
.Evaluate<const Literal*>(
*scatter->to_apply(),
{result_value_literal.get(), update_value_literal.get()})
{&result_value_literal, &update_value_literal})
.ConsumeValueOrDie();
// Clear visit states so that the we can use the evaluate again on the
// same computation.
embedded_evaluator.ResetVisitStates();
result->Set<ReturnT>(input_index, updated_result->Get<ReturnT>({}));
result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({}));
return true;
};

@ -2361,7 +2352,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {

auto result = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
return Status::OK();
}
@ -2575,7 +2566,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (ShapeUtil::Rank(iota->shape()) > 1) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[iota],
result->Broadcast(iota->shape(), {iota->iota_dimension()}));
result.Broadcast(iota->shape(), {iota->iota_dimension()}));
} else {
TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
parent_->evaluated_[iota] = std::move(result);
@ -2645,9 +2636,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}

template <typename IndexT>
StatusOr<std::unique_ptr<Literal>> DynamicSlice(
const Literal& operand_literal, const Literal& start_indices_literal,
const Shape& result_shape) {
StatusOr<Literal> DynamicSlice(const Literal& operand_literal,
const Literal& start_indices_literal,
const Shape& result_shape) {
auto start_indices_typed = start_indices_literal.data<IndexT>();
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
@ -2660,9 +2651,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}

std::vector<int64> operand_indices(start.size());
auto result = absl::make_unique<Literal>(result_shape);
Literal result(result_shape);
TF_RETURN_IF_ERROR(
result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
CHECK_GE(multi_index[i] + start[i], 0);
operand_indices[i] = multi_index[i] + start[i];
@ -2676,12 +2667,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}

template <typename IndexT>
StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
const Literal& operand_literal, const Literal& update_literal,
const Literal& start_indices_literal) {
auto result = operand_literal.CloneToUnique();
StatusOr<Literal> DynamicUpdateSlice(const Literal& operand_literal,
const Literal& update_literal,
const Literal& start_indices_literal) {
auto result = operand_literal.Clone();
auto start_indices_typed = start_indices_literal.data<IndexT>();
const auto rank = ShapeUtil::Rank(result->shape());
const auto rank = ShapeUtil::Rank(result.shape());
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
// Clamp the update start indices so the slice is in-bounds w.r.t the
@ -2689,15 +2680,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
for (int64 i = 0; i < rank; ++i) {
start[i] = std::min<int64>(
std::max<int64>(0, start[i]),
result->shape().dimensions(i) - update_literal.shape().dimensions(i));
result.shape().dimensions(i) - update_literal.shape().dimensions(i));
}
std::vector<int64> result_index(rank, 0);

auto func = [&](absl::Span<const int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
result->Set<ReturnT>(result_index,
update_literal.Get<ReturnT>(update_index));
result.Set<ReturnT>(result_index,
update_literal.Get<ReturnT>(update_index));
return true;
};

@ -2710,7 +2701,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result);
}

StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
StatusOr<Literal> ElementWiseUnaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
const Literal& operand_literal =
@ -2723,7 +2714,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result_literal);
}

StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
StatusOr<Literal> ElementWiseBinaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
binary_op) {
@ -2745,10 +2736,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);

auto result = absl::make_unique<Literal>(shape);
Literal result(shape);

TF_RETURN_IF_ERROR(
result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ConvertBinaryFunction(binary_op)(
lhs_literal.Get<ReturnT>(multi_index),
rhs_literal.Get<ReturnT>(multi_index));
@ -2757,7 +2748,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}

template <typename LhsType, typename RhsType, typename EhsType>
StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
StatusOr<Literal> ElementwiseTernaryOp(
HloInstruction* instruction,
const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
const auto shape = instruction->shape();
@ -2782,10 +2773,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);

auto result = absl::make_unique<Literal>(shape);
Literal result(shape);

TF_RETURN_IF_ERROR(
result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ternary_op(lhs_literal.Get<LhsType>(multi_index),
rhs_literal.Get<RhsType>(multi_index),
ehs_literal.Get<EhsType>(multi_index));
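For reference, the evaluator hunks above replace heap-allocated literals with value semantics: results are built in a stack Literal and returned through StatusOr<Literal>. A minimal sketch of the new style follows (the helper name MakeSquareLiteral and the float element type are illustrative assumptions, not part of the commit):

// Sketch only: builds a literal by value and returns it through StatusOr,
// mirroring the StatusOr<Literal> signatures introduced above.
StatusOr<Literal> MakeSquareLiteral(const Shape& shape) {
  Literal result(shape);  // value type; no absl::make_unique<Literal> needed
  TF_RETURN_IF_ERROR(result.Populate<float>(
      [](absl::Span<const int64> multi_index) {
        float v = static_cast<float>(multi_index[0]);
        return v * v;  // fill each element from its index
      }));
  return std::move(result);  // moves the buffer; no ownership wrapper
}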
@ -250,7 +250,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.has_literal());
TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(proto.literal()));
instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
break;
}
case HloOpcode::kFusion: {
@ -527,7 +527,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
std::unique_ptr<Literal> literal) {
Literal literal) {
return absl::make_unique<HloConstantInstruction>(std::move(literal));
}
@ -359,8 +359,7 @@ class HloInstruction {
const string& name);

// Creates a literal constant instruction.
static std::unique_ptr<HloInstruction> CreateConstant(
std::unique_ptr<Literal> literal);
static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);

// Creates an Iota instruction.
static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
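With CreateConstant now taking the literal by value, a call site looks roughly like the sketch below (the builder name and the R1 values are illustrative assumptions):

// Sketch only: the literal is moved into the instruction instead of being
// passed as a std::unique_ptr<Literal>.
HloComputation::Builder b("sketch");
HloInstruction* c = b.AddInstruction(HloInstruction::CreateConstant(
    LiteralUtil::CreateR1<float>({1.0f, 2.0f})));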
@ -845,8 +845,8 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
}

HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
: HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
HloConstantInstruction::HloConstantInstruction(Literal literal)
: HloInstruction(HloOpcode::kConstant, literal.shape()),
literal_(std::move(literal)) {}

HloConstantInstruction::HloConstantInstruction(const Shape& shape)
@ -854,7 +854,7 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape)

HloInstructionProto HloConstantInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
if (literal_ != nullptr) {
if (literal_.has_value()) {
*proto.mutable_literal() = literal_->ToProto();
}
return proto;
@ -876,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,

if (!mutable_array_subshape->has_layout() ||
!LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
literal_ = literal_->Relayout(new_layout, shape_index);
*literal_ = literal_->Relayout(new_layout, shape_index);
*mutable_array_subshape->mutable_layout() = new_layout;
}
}
@ -893,7 +893,8 @@ std::unique_ptr<HloInstruction>
HloConstantInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique());
CHECK(literal_.has_value());
return absl::make_unique<HloConstantInstruction>(literal_->Clone());
}

string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
@ -901,7 +902,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
CanonicalNameMap* canonical_name_map) const {
string operands;
// For constants, show the actual value in place of an empty operand list.
if (literal_ != nullptr &&
if (literal_.has_value() &&
((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
options.print_large_constants())) {
// Literal::ToString emits multidimensional arrays over multiple
@ -936,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag,

HloInstructionProto HloTraceInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_literal() = literal_->ToProto();
*proto.mutable_literal() = literal_.ToProto();
return proto;
}
@ -580,13 +580,13 @@ class HloSliceInstruction : public HloInstruction {

class HloConstantInstruction : public HloInstruction {
public:
explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
explicit HloConstantInstruction(Literal literal);
// Used when the literal is too large and dropped.
explicit HloConstantInstruction(const Shape& shape);
// Returns the literal associated with this instruction.
const Literal& literal() const { return *literal_; }
// Returns whether there is literal associated with this instruction.
bool HasLiteral() const { return literal_ != nullptr; }
bool HasLiteral() const { return literal_.has_value(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;

@ -610,15 +610,14 @@ class HloConstantInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// TODO(b/36360764): Remove unique_ptr wrapping.
std::unique_ptr<Literal> literal_;
absl::optional<Literal> literal_;
};

class HloTraceInstruction : public HloInstruction {
public:
explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
// Returns a tag to be used in tracing.
string TracingTag() const { return literal_->GetR1U8AsString(); }
string TracingTag() const { return literal_.GetR1U8AsString(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;

@ -631,8 +630,7 @@ class HloTraceInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// TODO(b/36360764): Remove unique_ptr wrapping.
std::unique_ptr<Literal> literal_;
Literal literal_;
};

class HloFusionInstruction : public HloInstruction {
@ -105,16 +105,13 @@ class HloParser {
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
bool ParseControlPredecessors(HloInstruction* instruction);
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape);
bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape);
bool ParseLiteral(Literal* literal, const Shape& shape);
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
bool ParseDenseLiteral(Literal* literal, const Shape& shape);
bool ParseSparseLiteral(Literal* literal, const Shape& shape);
template <typename LiteralNativeT>
bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
const Shape& shape);
bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape);

// Sets the sub-value of literal at the given index to the given value. The
// literal's shape must have the default layout.
@ -577,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kConstant: {
std::unique_ptr<Literal> literal;
Literal literal;
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
@ -1810,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) {
// literal
// ::= tuple
// ::= non_tuple
bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape)
: ParseNonTupleLiteral(literal, shape);
}
@ -1821,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
// literal_list
// ::= /*empty*/
// ::= literal (',' literal)*
bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return TokenError(StrCat("expects tuple constant in shape ",
ShapeUtil::HumanString(shape)));
@ -1830,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
return false;
}
std::vector<std::unique_ptr<Literal>> elements(
ShapeUtil::TupleElementCount(shape));
std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));

if (lexer_.GetKind() == TokKind::kRparen) {
// empty
@ -1857,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
// ::= rank01
// ::= rank2345
// rank2345 ::= shape sparse_or_nested_array
bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
if (LayoutUtil::IsSparseArray(shape)) {
return ParseSparseLiteral(literal, shape);
}
@ -1867,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
return ParseDenseLiteral(literal, shape);
}

bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
const tensorflow::int64 rank = ShapeUtil::Rank(shape);
if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
return false;
@ -1962,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
// TODO(congliu): bool type literals with rank >= 1 are actually
// printed in a compact form instead of "true" or "false". Fix that.
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
linear_index++, literal->get())) {
linear_index++, literal)) {
return false;
}
lexer_.Lex();
@ -1973,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
return Error(loc, StrCat("expects integer for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
if (!SetValueInLiteral(value, linear_index++, literal->get())) {
if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else if (primitive_util::IsFloatingPointType(shape.element_type())) {
@ -1984,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
loc, StrCat("expect floating point value for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
if (!SetValueInLiteral(value, linear_index++, literal->get())) {
if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else {
@ -1996,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
} // end of switch
} while (nest_level > 0);

*literal = (*literal)->Relayout(shape.layout());
*literal = literal->Relayout(shape.layout());
return true;
}

bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return false;
}
@ -2041,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
}

template <typename LiteralNativeT>
bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
const Shape& shape) {
bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) {
std::vector<tensorflow::int64> index;

tensorflow::int64 rank = ShapeUtil::Rank(shape);

*literal = absl::make_unique<Literal>(shape);
*literal = Literal(shape);

if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of a sparse literal")) {
@ -2121,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
return false;
}

if ((*literal)->sparse_element_count() + 1 ==
if (literal->sparse_element_count() + 1 ==
LayoutUtil::MaxSparseElements(shape.layout())) {
return Error(
lexer_.GetLoc(),
@ -2129,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
ShapeUtil::HumanStringWithLayout(shape)));
}

(*literal)->AppendSparseElement(index, value);
literal->AppendSparseElement(index, value);
}

(*literal)->SortSparseElements();
literal->SortSparseElements();
return true;
}
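With the parser now filling a caller-owned Literal, a call site looks roughly like this sketch (the shape variable is an illustrative assumption):

// Sketch only: the parsed constant lands in a stack Literal instead of a
// std::unique_ptr<Literal> out-parameter.
Literal literal;
if (!ParseLiteral(&literal, shape)) {
  return false;  // parse error already reported via TokenError/Error
}
// `literal` now owns the parsed value; move it where it is needed.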
@ -118,16 +118,16 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
}

StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
const absl::Span<const std::unique_ptr<Literal>> literals) {
const absl::Span<const Literal> literals) {
std::vector<const Literal*> literal_pointers;
literal_pointers.reserve(literals.size());
for (const auto& literal : literals) {
literal_pointers.push_back(literal.get());
literal_pointers.push_back(&literal);
}
return TransferLiteralsToDevice(literal_pointers);
}

StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
const ShapedBuffer& buffer) {
TF_ASSIGN_OR_RETURN(
auto stream, backend().BorrowStream(backend().default_stream_executor()));
@ -135,7 +135,7 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
buffer);
}

StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
StatusOr<Literal> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const absl::Span<const Literal* const> arguments, bool run_hlo_passes,
ExecutionProfile* profile) {
@ -150,15 +150,15 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
return TransferLiteralFromDevice(result);
}

StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const absl::Span<const std::unique_ptr<Literal>> arguments,
bool run_hlo_passes, ExecutionProfile* profile) {
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
const absl::Span<const Literal> arguments,
bool run_hlo_passes,
ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
std::vector<const Literal*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
argument_pointers.push_back(argument.get());
argument_pointers.push_back(&argument);
}
return Execute(
/*module=*/std::move(module),
@ -204,7 +204,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
/*profile=*/profile);
}

StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options) {
TF_ASSIGN_OR_RETURN(
@ -290,9 +290,9 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
VLOG(1) << "Starting outfeed on device " << device;
for (int64 step = 1;
options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
auto literal = absl::make_unique<Literal>();
Literal literal;
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
executor, options.outfeed_shape, literal.get()));
executor, options.outfeed_shape, &literal));
if (options.outfeed_values != nullptr) {
options.outfeed_values->push_back(std::move(literal));
}
@ -310,10 +310,10 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
argument_buffer_slices));
LOG(INFO) << "Replicated execution terminated";

std::vector<std::unique_ptr<Literal>> exec_results;
std::vector<Literal> exec_results;
for (int64 i = 0; i < options.num_replicas; ++i) {
TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
TF_ASSIGN_OR_RETURN(Literal literal,
backend().transfer_manager()->TransferLiteralFromDevice(
streams[i].get(), results[i]));
exec_results.push_back(std::move(literal));
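Runner callers now pass and receive literals by value. A minimal sketch, assuming an already-constructed HloRunner `runner` and an HloModule `module` (both illustrative):

// Sketch only: Execute takes a span of Literal values and returns
// StatusOr<Literal> rather than StatusOr<std::unique_ptr<Literal>>.
std::vector<Literal> args;
args.push_back(LiteralUtil::CreateR0<float>(42.0f));
TF_ASSIGN_OR_RETURN(Literal result,
                    runner.Execute(std::move(module), args));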
@ -72,7 +72,7 @@ class HloRunner {

// A pointer to a vector where the outfeed values will be stored. If
// nullptr, the values will be read and discarded.
std::vector<std::unique_ptr<Literal>>* outfeed_values = nullptr;
std::vector<Literal>* outfeed_values = nullptr;

// Whether the HLO passes should be run on the input module. Usually
// saved modules are coming from after the HLO pass pipeline, so triggering
@ -106,24 +106,23 @@ class HloRunner {
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
const absl::Span<const Literal* const> literals);
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
const absl::Span<const std::unique_ptr<Literal>> literals);
StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
const ShapedBuffer& buffer);
const absl::Span<const Literal> literals);
StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);

// Executes the given module with given literals as input and returns the
// result as a Literal.
//
// If run_hlo_passes is false, the module will be executed without Hlo
// optimization.
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
const absl::Span<const Literal* const> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
const absl::Span<const Literal* const> arguments,
bool run_hlo_passes = true,
ExecutionProfile* profile = nullptr);

StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
const absl::Span<const std::unique_ptr<Literal>> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
const absl::Span<const Literal> arguments,
bool run_hlo_passes = true,
ExecutionProfile* profile = nullptr);

// As Execute(), but accepts and returns device buffers instead of host
// buffers.
@ -140,7 +139,7 @@ class HloRunner {
// Executes a given HLO module into a set of replicas, and returns a map
// with the replica number as key, and the corresponding returned literal as
// value.
StatusOr<std::vector<std::unique_ptr<Literal>>> ExecuteReplicated(
StatusOr<std::vector<Literal>> ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options);
@ -290,8 +290,8 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
padding_config.add_dimensions()->set_interior_padding(-1);
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {100}), param,
builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(F32).CloneToUnique())),
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
padding_config));

auto module = CreateNewModule();
@ -314,8 +314,8 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
padding_config.add_dimensions()->set_interior_padding(-1);
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {100}), param,
builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(F32).CloneToUnique())),
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())),
padding_config));

auto module = CreateNewModule();
@ -918,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
// inner_broadcast_result is the Broadcast'(Const0) bit in
// BinaryOp(Broadcast'(Const0), Const1)
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> inner_broadcast_result,
Literal inner_broadcast_result,
broadcast_const_operand->literal().Broadcast(
scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));

@ -928,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
opcode, scalar_indexed_const->literal(), *inner_broadcast_result)));
opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
} else {
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
opcode, *inner_broadcast_result, scalar_indexed_const->literal())));
opcode, inner_broadcast_result, scalar_indexed_const->literal())));
}

ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
@ -347,21 +347,19 @@ class IndexedArrayAnalysis {
}
}

Literal* TakeOwnership(std::unique_ptr<Literal> literal) {
Literal* TakeOwnership(Literal literal) {
owned_literals_.push_back(std::move(literal));
return owned_literals_.back().get();
return &owned_literals_.back();
}

StatusOr<Literal*> TakeOwnership(
StatusOr<std::unique_ptr<Literal>> literal_or_error) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
std::move(literal_or_error));
StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
owned_literals_.push_back(std::move(literal));
return owned_literals_.back().get();
return &owned_literals_.back();
}

std::vector<std::unique_ptr<Array>> owned_tensors_;
std::vector<std::unique_ptr<Literal>> owned_literals_;
std::vector<Literal> owned_literals_;
tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
};
@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}

// Test that `constant` function is changed to `broadcast`.
@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}

TEST_F(InlinerTest, MapSubtractOppositeOrder) {
@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
|
||||
|
||||
// Transform the ShapedBuffer arguments into literals which the evaluator
|
||||
// consumes.
|
||||
std::vector<std::unique_ptr<Literal>> arg_literals;
|
||||
std::vector<Literal> arg_literals;
|
||||
for (int64 p = 0; p < computation->num_parameters(); ++p) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> arg_literal,
|
||||
TF_ASSIGN_OR_RETURN(Literal arg_literal,
|
||||
transfer_manager->TransferLiteralFromDevice(
|
||||
run_options->stream(), *arguments[p]));
|
||||
arg_literals.push_back(std::move(arg_literal));
|
||||
}
|
||||
|
||||
// Execute the graph using the HloEvaluator.
|
||||
std::unique_ptr<Literal> result_literal;
|
||||
Literal result_literal;
|
||||
{
|
||||
tensorflow::mutex_lock lock(evaluator_lock_);
|
||||
TF_ASSIGN_OR_RETURN(result_literal,
|
||||
evaluator_->Evaluate<std::unique_ptr<Literal>>(
|
||||
*computation, arg_literals));
|
||||
TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate<Literal>(
|
||||
*computation, arg_literals));
|
||||
}
|
||||
|
||||
// Transform the result literal back into a ShapedBuffer.
|
||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
|
||||
transfer_manager->AllocateScopedShapedBuffer(
|
||||
result_literal->shape(), run_options->allocator(),
|
||||
result_literal.shape(), run_options->allocator(),
|
||||
executor->device_ordinal()));
|
||||
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
|
||||
run_options->stream(), *result_literal, result));
|
||||
run_options->stream(), result_literal, result));
|
||||
|
||||
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
|
||||
|
||||
|
@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
|
||||
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
|
||||
auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
|
||||
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
|
||||
Shape ashape = constant_literal1->shape();
|
||||
Shape ashape = constant_literal1.shape();
|
||||
|
||||
auto constant1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(std::move(constant_literal1)));
|
||||
|
@ -68,9 +68,9 @@ Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> literal,
Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, *argument));
*module->add_arguments() = literal->ToProto();
*module->add_arguments() = literal.ToProto();
}
return Status::OK();
}
@ -80,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
TransferManager* transfer_manager, HloSnapshot* module) {
module->clear_result();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> literal,
Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, result));
*module->mutable_result() = literal->ToProto();
*module->mutable_result() = literal.ToProto();
return Status::OK();
}

@ -928,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
shaped_buffer->device_ordinal()));

TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> result_literal,
Literal result_literal,
execute_backend_->transfer_manager()->TransferLiteralFromDevice(
stream.get(), *shaped_buffer));

if (LayoutUtil::LayoutsInShapesEqual(*return_shape,
result_literal->shape())) {
*result->mutable_literal() = result_literal->ToProto();
if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) {
*result->mutable_literal() = result_literal.ToProto();
} else {
*result->mutable_literal() =
result_literal->Relayout(*return_shape)->ToProto();
result_literal.Relayout(*return_shape).ToProto();
}
return Status::OK();
}
@ -959,9 +958,9 @@ std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(

Status Service::TransferToServer(const TransferToServerRequest* arg,
TransferToServerResponse* result) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
const Shape& shape = literal->shape();
const Shape& shape = literal.shape();

std::vector<se::StreamExecutor*> replicas;
if (arg->has_device_handle()) {
@ -983,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg,
TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
stream.get(), *literal, shaped_buffer));
stream.get(), literal, shaped_buffer));
replicated_buffers.emplace_back(std::move(shaped_buffer));
}
TF_ASSIGN_OR_RETURN(*result->mutable_data(),
@ -1018,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
executor = replicas[arg->replica_id()];
}

TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
executor, *literal);
return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
literal);
}

Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
@ -1049,8 +1048,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,

TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
executor, arg->shape_with_layout(), *literal));
*result->mutable_literal() = literal->ToProto();
executor, arg->shape_with_layout(), literal));
*result->mutable_literal() = literal.ToProto();
return Status::OK();
}

@ -1085,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
HloModule::CreateFromProto(arg->computation(), config));

HloEvaluator evaluator;
TF_ASSIGN_OR_RETURN(auto result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
*module, /*arg_literals=*/{}));
TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<Literal>(
*module, /*arg_literals=*/{}));

// Since the result layout is non-effective to the Evaluator results, explicit
// relayout here.
//
// TODO(b/77824332): Make HloEvaluator take care of the re-layout.
if (arg->has_output_layout()) {
result_literal = result_literal->Relayout(arg->output_layout());
result_literal = result_literal.Relayout(arg->output_layout());
}
*result->mutable_literal() = result_literal->ToProto();
*result->mutable_literal() = result_literal.ToProto();

return Status::OK();
}
@ -42,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() {
return r;
}

StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer) {
StatusOr<std::unique_ptr<Literal>> ret;
StatusOr<Literal> ret;

se::Stream* substream = stream->GetOrCreateSubStream();
substream->ThenWaitFor(stream);
@ -63,7 +63,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
if (!s.ok()) {
return s;
}
return absl::make_unique<Literal>(std::move(literal));
return std::move(literal);
}

Status TransferManager::TransferLiteralFromDevice(
@ -99,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice(
return substream->BlockHostUntilDone();
}

StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
StatusOr<Literal> TransferManager::TransferArrayFromDevice(
se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source) {
StatusOr<std::unique_ptr<Literal>> ret;
StatusOr<Literal> ret;
// Implement the synchronous version by waiting on the asynchronous version.
// Use a substream so that if we are called from a HostCallback we don't
// deadlock.
@ -122,7 +122,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
if (!s.ok()) {
return s;
}
return absl::make_unique<Literal>(std::move(literal));
return std::move(literal);
}

Status TransferManager::TransferArrayToDevice(
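Host transfers now hand the literal back directly. A minimal sketch, assuming a TransferManager* `transfer_manager`, a se::Stream* `stream`, and a ShapedBuffer `device_buffer` already exist (all illustrative):

// Sketch only: TransferLiteralFromDevice returns StatusOr<Literal>, so the
// result binds as a value without dereferencing a unique_ptr.
TF_ASSIGN_OR_RETURN(
    Literal literal,
    transfer_manager->TransferLiteralFromDevice(stream, device_buffer));
const Shape& host_shape = literal.shape();  // value semantics: '.' not '->'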
@ -57,7 +57,7 @@ class TransferManager {
// without waiting for any other operation on a stream to complete.
//
// This function should be avoided in favor of the asynchronous version below.
virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
virtual StatusOr<Literal> TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer);
virtual Status TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
@ -113,9 +113,9 @@ class TransferManager {
Status TransferArrayToDeviceAsync(se::Stream* stream,
const LiteralSlice& literal,
const se::DeviceMemoryBase& dest);
StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice(
se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source);
StatusOr<Literal> TransferArrayFromDevice(se::Stream* stream,
const Shape& shape,
const se::DeviceMemoryBase& source);

// Transfers the given literal into the Infeed interface of the device,
// using the given executor.
@ -555,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
// Construct a tuple constant and kCopy it. Verify the points-to set of the
// copy correctly correctly points into the nested elements of the constant.
auto builder = HloComputation::Builder(TestName());
auto tuple_constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
{LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
Literal elements[] = {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
LiteralUtil::CreateR1<float>({2.0, 42})};
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
@ -183,8 +183,7 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
HloEvaluator evaluator(/*max_loop_iterations=*/0);
auto* while_init = while_op->mutable_operand(0);
auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
StatusOr<std::unique_ptr<Literal>> indvar_init_result =
evaluator.Evaluate(indvar_init);
StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
if (!indvar_init_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable init: "
<< indvar_init_result.status();
@ -197,31 +196,27 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);

// The initial value of the induction variable.
std::unique_ptr<Literal> indvar_iter_val =
std::move(indvar_init_result).ValueOrDie();
Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
for (int64 trip_count = 0; trip_count != max_value_returned + 1;
++trip_count) {
auto* while_cond = while_op->while_condition();
auto* while_cond_root = while_cond->root_instruction();
auto* while_cond_indvar = NonConstantOperand(while_cond_root);
StatusOr<std::unique_ptr<Literal>> result =
evaluator.EvaluateWithSubstitutions(
while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}});
StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
if (!result.ok()) {
VLOG(2) << "Couldn't evaluate while cond: " << result.status();
return nullopt;
}
if (result.ValueOrDie()->data<bool>() == absl::Span<const bool>{false}) {
if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
VLOG(2) << "Loop has static trip count of " << trip_count;
return trip_count;
}

// Calculate the value of the induction variable after one iteration of the
// loop, and check whether the while condition is true with this new value.
StatusOr<std::unique_ptr<Literal>> indvar_next_result =
evaluator.EvaluateWithSubstitutions(
while_body_indvar_update,
{{while_body_indvar, indvar_iter_val.get()}});
StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
if (!indvar_next_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable update: "
<< indvar_next_result.status();
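The substitution API above now takes plain pointers to value literals. A minimal sketch, assuming an HloEvaluator `evaluator` and instruction pointers `root_instruction` and `bound_instruction` (all illustrative):

// Sketch only: the substitution map holds const Literal*, so a stack literal
// is passed by address instead of via unique_ptr::get().
Literal bound_value = LiteralUtil::CreateR0<int32>(0);
StatusOr<Literal> evaluated = evaluator.EvaluateWithSubstitutions(
    root_instruction, {{bound_instruction, &bound_value}});
if (evaluated.ok()) {
  // Use evaluated.ValueOrDie() as a Literal value.
}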
@ -41,7 +41,6 @@ limitations under the License.
namespace xla {
namespace {

class ArrayElementwiseOpTest : public ClientLibraryTestBase {
public:
ErrorSpec error_spec_{0.0001, 0.0001};
@ -227,10 +226,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0x8000000000000000LL,
0x8000000000000000LL,
1};
std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

std::vector<uint64> rhs{1,
0x7FFFFFFFFFFFFFFLL,
@ -241,10 +240,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0,
1,
0x8000000000000000LL};
std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

Add(lhs_param, rhs_param);

@ -267,10 +266,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
1,
0,
-1};
std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

std::vector<int64> rhs{-1,
0,
@ -280,10 +279,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
0x7FFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL};
std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

Sub(lhs_param, rhs_param);

@ -299,16 +298,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
XlaBuilder b(TestName());

std::vector<uint64> lhs{static_cast<uint64>(0x8000000000000000ULL)};
std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");

std::vector<uint64> rhs{static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)};
std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");

Lt(lhs_param, rhs_param);

ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)});
ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)});
}

TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
@ -321,16 +320,16 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
b_values.push_back(2 * i / static_cast<float>(count + 2));
}

std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
client_->TransferToServer(a_literal).ConsumeValueOrDie();
auto a_constant = ConstantR1<float>(&builder, a_values);
auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param");
auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");

std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
std::unique_ptr<GlobalData> b_data =
client_->TransferToServer(*b_literal).ConsumeValueOrDie();
auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param");
client_->TransferToServer(b_literal).ConsumeValueOrDie();
auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param");
auto b_param = ConstantR1<float>(&builder, b_values);

auto sum1 = Add(a_constant, b_constant);
@ -1422,12 +1421,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values);
Literal param_literal = LiteralUtil::CreateR1<float>(values);
std::unique_ptr<GlobalData> param_data =
client_->TransferToServer(*param_literal).ConsumeValueOrDie();
client_->TransferToServer(param_literal).ConsumeValueOrDie();

auto sum = ConstantR0<float>(&b, 0.0f);
auto param = Parameter(&b, 0, param_literal->shape(), "param");
auto param = Parameter(&b, 0, param_literal.shape(), "param");
for (float exponent : exponents) {
sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
}
@ -1450,14 +1449,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
client_->TransferToServer(literal0).ConsumeValueOrDie();
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
client_->TransferToServer(literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Pow(Exp(param0), param1);

std::vector<float> expected(values0.size());
@ -1475,14 +1474,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
client_->TransferToServer(literal0).ConsumeValueOrDie();
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
client_->TransferToServer(literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Log(Pow(param0, param1));

std::vector<float> expected(values0.size());
@ -1500,14 +1499,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
client_->TransferToServer(literal0).ConsumeValueOrDie();
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
client_->TransferToServer(literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Mul(Exp(param0), Exp(param1));

std::vector<float> expected(values0.size());
@ -1525,14 +1524,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
client_->TransferToServer(literal0).ConsumeValueOrDie();
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
client_->TransferToServer(literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Div(param0, Exp(param1));

std::vector<float> expected(values0.size());
@ -1551,20 +1550,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};

std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
client_->TransferToServer(literal0).ConsumeValueOrDie();

std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
client_->TransferToServer(literal1).ConsumeValueOrDie();

std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
client_->TransferToServer(literal2).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
|
||||
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
|
||||
auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
|
||||
Div(Div(param0, param1), param2);
|
||||
|
||||
std::vector<float> expected(values0.size());
|
||||
@ -1583,21 +1582,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
|
||||
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
|
||||
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
|
||||
|
||||
std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
|
||||
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
|
||||
std::unique_ptr<GlobalData> data0 =
|
||||
client_->TransferToServer(*literal0).ConsumeValueOrDie();
|
||||
client_->TransferToServer(literal0).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
|
||||
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
|
||||
std::unique_ptr<GlobalData> data1 =
|
||||
client_->TransferToServer(*literal1).ConsumeValueOrDie();
|
||||
client_->TransferToServer(literal1).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
|
||||
Literal literal2 = LiteralUtil::CreateR1<float>(values2);
|
||||
std::unique_ptr<GlobalData> data2 =
|
||||
client_->TransferToServer(*literal2).ConsumeValueOrDie();
|
||||
client_->TransferToServer(literal2).ConsumeValueOrDie();
|
||||
|
||||
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
|
||||
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
|
||||
auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
|
||||
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
|
||||
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
|
||||
auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
|
||||
Div(param0, Div(param1, param2));
|
||||
|
||||
std::vector<float> expected(values0.size());
|
||||
@ -1616,21 +1615,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
|
||||
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
|
||||
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};
|
||||
|
||||
std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
|
||||
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
|
||||
std::unique_ptr<GlobalData> data0 =
|
||||
client_->TransferToServer(*literal0).ConsumeValueOrDie();
|
||||
client_->TransferToServer(literal0).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
|
||||
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
|
||||
std::unique_ptr<GlobalData> data1 =
|
||||
client_->TransferToServer(*literal1).ConsumeValueOrDie();
|
||||
client_->TransferToServer(literal1).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
|
||||
Literal literal2 = LiteralUtil::CreateR1<float>(values2);
|
||||
std::unique_ptr<GlobalData> data2 =
|
||||
client_->TransferToServer(*literal2).ConsumeValueOrDie();
|
||||
client_->TransferToServer(literal2).ConsumeValueOrDie();
|
||||
|
||||
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
|
||||
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
|
||||
auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
|
||||
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
|
||||
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
|
||||
auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
|
||||
Div(param0, Pow(param1, param2));
|
||||
|
||||
std::vector<float> expected(values0.size());
|
||||
@ -1650,26 +1649,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
|
||||
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
|
||||
std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};
|
||||
|
||||
std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
|
||||
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
|
||||
std::unique_ptr<GlobalData> data0 =
|
||||
client_->TransferToServer(*literal0).ConsumeValueOrDie();
|
||||
client_->TransferToServer(literal0).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
|
||||
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
|
||||
std::unique_ptr<GlobalData> data1 =
|
||||
client_->TransferToServer(*literal1).ConsumeValueOrDie();
|
||||
client_->TransferToServer(literal1).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
|
||||
Literal literal2 = LiteralUtil::CreateR1<float>(values2);
|
||||
std::unique_ptr<GlobalData> data2 =
|
||||
client_->TransferToServer(*literal2).ConsumeValueOrDie();
|
||||
client_->TransferToServer(literal2).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<Literal> literal3 = LiteralUtil::CreateR1<float>(values3);
|
||||
Literal literal3 = LiteralUtil::CreateR1<float>(values3);
|
||||
std::unique_ptr<GlobalData> data3 =
|
||||
client_->TransferToServer(*literal3).ConsumeValueOrDie();
|
||||
client_->TransferToServer(literal3).ConsumeValueOrDie();
|
||||
|
||||
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
|
||||
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
|
||||
auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
|
||||
auto param3 = Parameter(&b, 3, literal3->shape(), "param2");
|
||||
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
|
||||
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
|
||||
auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
|
||||
auto param3 = Parameter(&b, 3, literal3.shape(), "param2");
|
||||
Div(Div(param0, param1), Div(param2, param3));
|
||||
|
||||
std::vector<float> expected(values0.size());
|
||||
@ -2096,18 +2095,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
std::unique_ptr<Literal> param0_literal =
|
||||
Literal param0_literal =
|
||||
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
|
||||
std::unique_ptr<GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<Literal> param1_literal =
|
||||
Literal param1_literal =
|
||||
LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
|
||||
std::unique_ptr<GlobalData> param1_data =
|
||||
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
|
||||
client_->TransferToServer(param1_literal).ConsumeValueOrDie();
|
||||
|
||||
auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
|
||||
auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
|
||||
auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
|
||||
auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
|
||||
Add(p0, p1);
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
|
||||
@ -2118,18 +2117,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
std::unique_ptr<Literal> param0_literal =
|
||||
Literal param0_literal =
|
||||
LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
|
||||
std::unique_ptr<GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<Literal> param1_literal =
|
||||
Literal param1_literal =
|
||||
LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
|
||||
std::unique_ptr<GlobalData> param1_data =
|
||||
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
|
||||
client_->TransferToServer(param1_literal).ConsumeValueOrDie();
|
||||
|
||||
auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
|
||||
auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
|
||||
auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
|
||||
auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
|
||||
Add(p0, p1);
|
||||
|
||||
Array3D<float> expected(0, 7, 0);
|
||||
@ -2140,13 +2139,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
std::unique_ptr<Literal> param0_literal =
|
||||
Literal param0_literal =
|
||||
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
|
||||
std::unique_ptr<GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
|
||||
|
||||
auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
|
||||
auto p = Parameter(&builder, 0, param0_literal->shape(), "param0");
|
||||
auto p = Parameter(&builder, 0, param0_literal.shape(), "param0");
|
||||
Add(a, p);
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
|
||||
@ -2206,9 +2205,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
|
||||
0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31,
|
||||
-0.79, 1.41, 1.21, 1.05});
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto input_data,
|
||||
client_->TransferToServer(*input_literal));
|
||||
client_->TransferToServer(input_literal));
|
||||
|
||||
auto input = Parameter(&builder, 0, input_literal->shape(), "input");
|
||||
auto input = Parameter(&builder, 0, input_literal.shape(), "input");
|
||||
Tanh(input);
|
||||
|
||||
ComputeAndCompareR1<float>(
|
||||
@ -2239,7 +2238,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
|
||||
|
||||
// Just to help make sense of the scales here -- exp(89) saturates float32 and
|
||||
// exp(-10) is smaller than our error spec.
|
||||
std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
|
||||
Literal input_literal = LiteralUtil::CreateR1<float>(
|
||||
{1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31,
|
||||
-1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5,
|
||||
-193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4,
|
||||
@ -2252,16 +2251,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
|
||||
78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3,
|
||||
86.4, 86.5, 87.6, 87.7, 87.8, 87.9});
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
|
||||
client_->TransferToServer(*input_literal));
|
||||
client_->TransferToServer(input_literal));
|
||||
|
||||
auto input = Parameter(&builder, 0, input_literal->shape(), "input");
|
||||
auto input = Parameter(&builder, 0, input_literal.shape(), "input");
|
||||
Exp(input);
|
||||
|
||||
std::vector<float> expected_result;
|
||||
int64 input_size = input_literal->shape().dimensions(0);
|
||||
int64 input_size = input_literal.shape().dimensions(0);
|
||||
expected_result.reserve(input_size);
|
||||
for (int64 i = 0; i < input_size; i++) {
|
||||
expected_result.push_back(std::exp(input_literal->Get<float>({i})));
|
||||
expected_result.push_back(std::exp(input_literal.Get<float>({i})));
|
||||
}
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
|
||||
@ -2273,7 +2272,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
|
||||
// implementation on XLA CPU.
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
|
||||
Literal input_literal = LiteralUtil::CreateR1<float>(
|
||||
{-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198,
|
||||
-167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9,
|
||||
198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04,
|
||||
@ -2290,16 +2289,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
|
||||
1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33,
|
||||
1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
|
||||
client_->TransferToServer(*input_literal));
|
||||
client_->TransferToServer(input_literal));
|
||||
|
||||
auto input = Parameter(&builder, 0, input_literal->shape(), "input");
|
||||
auto input = Parameter(&builder, 0, input_literal.shape(), "input");
|
||||
Log(input);
|
||||
|
||||
std::vector<float> expected_result;
|
||||
int64 input_size = input_literal->shape().dimensions(0);
|
||||
int64 input_size = input_literal.shape().dimensions(0);
|
||||
expected_result.reserve(input_size);
|
||||
for (int64 i = 0; i < input_size; i++) {
|
||||
expected_result.push_back(std::log(input_literal->Get<float>({i})));
|
||||
expected_result.push_back(std::log(input_literal.Get<float>({i})));
|
||||
}
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
|
||||
@ -2465,10 +2464,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
|
||||
auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0});
|
||||
Tuple(&builder, {cmp_dim_0, cmp_dim_1});
|
||||
|
||||
auto expected = LiteralUtil::MakeTuple(
|
||||
{LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(),
|
||||
LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()});
|
||||
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
|
||||
auto expected = LiteralUtil::MakeTupleFromSlices(
|
||||
{LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}),
|
||||
LiteralUtil::CreateR2<bool>({{true, false}, {false, false}})});
|
||||
ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
|
||||
@ -2821,10 +2820,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
|
||||
std::iota(r1.begin(), r1.end(), 1.0);
|
||||
|
||||
XlaBuilder builder(TestName());
|
||||
std::unique_ptr<Literal> a_literal =
|
||||
LiteralUtil::CreateR4FromArray4DWithLayout(
|
||||
r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
|
||||
auto a = ConstantLiteral(&builder, *a_literal);
|
||||
Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
|
||||
r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
|
||||
auto a = ConstantLiteral(&builder, a_literal);
|
||||
auto b = ConstantR1<float>(&builder, r1);
|
||||
Add(a, b, {1});
|
||||
|
||||
@ -2886,11 +2884,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
|
||||
auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
|
||||
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
|
||||
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
|
||||
auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
|
||||
auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
|
||||
|
||||
auto x = Parameter(&builder, 0, x_literal->shape(), "x");
|
||||
auto y = Parameter(&builder, 1, y_literal->shape(), "y");
|
||||
auto x = Parameter(&builder, 0, x_literal.shape(), "x");
|
||||
auto y = Parameter(&builder, 1, y_literal.shape(), "y");
|
||||
auto slice = Slice(x, {1}, {2}, {1});
|
||||
Sub(slice, y);
|
||||
|
||||
|
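The hunks above all apply the same mechanical rewrite. As a rough sketch (not code from this commit; it assumes the usual ClientLibraryTestBase member client_ and an XlaBuilder b), LiteralUtil factories now return xla::Literal by value, so a typical test body goes from:
//   std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(values);
//   auto data = client_->TransferToServer(*literal).ConsumeValueOrDie();
//   auto param = Parameter(&b, 0, literal->shape(), "param");
// to the value-semantics form:
xla::Literal literal = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
std::unique_ptr<xla::GlobalData> data =
    client_->TransferToServer(literal).ConsumeValueOrDie();
xla::XlaOp param = Parameter(&b, 0, literal.shape(), "param");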
@ -63,7 +63,7 @@ class BatchNormalizationTest
{5.0f, 4.4f}, // p2
});
input_array_.FillWithPZ(pz);
input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
CHECK_EQ(kSamples, input_array_.planes());
CHECK_EQ(kZ, input_array_.depth());
CHECK_EQ(kY, input_array_.height());
@ -242,14 +242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
{{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
.get(),
LiteralUtil::CreateR1<float>({4, 5}).get(),
LiteralUtil::CreateR1<float>({5, 5}).get()});
{{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
LiteralUtil::CreateR1<float>({4, 5}),
LiteralUtil::CreateR1<float>({5, 5})});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
{{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
.get(),
LiteralUtil::CreateR1<float>({4, 5}).get(),
LiteralUtil::CreateR1<float>({5, 5}).get()});
{{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
LiteralUtil::CreateR1<float>({4, 5}),
LiteralUtil::CreateR1<float>({5, 5})});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/1, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
{LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
.get(),
LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});
ComputeAndCompareTuple(&builder, *expected,
ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/-100, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR3FromArray3D<float>(
{{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
.get(),
LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
{{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});
ComputeAndCompareTuple(&builder, *expected,
ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
{{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
.get(),
LiteralUtil::CreateR1<float>({0, 0}).get(),
LiteralUtil::CreateR1<float>({16, 20}).get()});
{{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
LiteralUtil::CreateR1<float>({0, 0}),
LiteralUtil::CreateR1<float>({16, 20})});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
struct BatchNormTestParam {
@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
Parameter(&builder, 0, input_literal->shape(), "input");
Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
Parameter(&builder, 1, scale_literal->shape(), "offset");
Parameter(&builder, 1, scale_literal.shape(), "offset");
auto offset_activations =
Parameter(&builder, 2, offset_literal->shape(), "scale");
Parameter(&builder, 2, offset_literal.shape(), "scale");
auto expected = LiteralUtil::MakeTuple(
{expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
LiteralUtil::CreateR1<float>(var).get()});
auto expected = LiteralUtil::MakeTupleFromSlices(
{expected_normalized, LiteralUtil::CreateR1<float>(mean),
LiteralUtil::CreateR1<float>(var)});
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
client_->TransferToServer(offset_literal).ConsumeValueOrDie();
BatchNormTraining(input_activations, scale_activations, offset_activations,
epsilon, feature_index);
@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
ComputeAndCompareTuple(
&builder, *expected,
&builder, expected,
{input_data.get(), scale_data.get(), offset_data.get()},
ErrorSpec(0.01, 1));
}
@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
Parameter(&builder, 0, input_literal->shape(), "input");
Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
Parameter(&builder, 1, scale_literal->shape(), "offset");
Parameter(&builder, 1, scale_literal.shape(), "offset");
auto offset_activations =
Parameter(&builder, 2, offset_literal->shape(), "scale");
auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean");
Parameter(&builder, 2, offset_literal.shape(), "scale");
auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
auto variance_activations =
Parameter(&builder, 4, var_literal->shape(), "variance");
Parameter(&builder, 4, var_literal.shape(), "variance");
Array4D<float> expected = normalized;
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
client_->TransferToServer(offset_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> variance_data =
client_->TransferToServer(*var_literal).ConsumeValueOrDie();
client_->TransferToServer(var_literal).ConsumeValueOrDie();
BatchNormInference(input_activations, scale_activations, offset_activations,
mean_activations, variance_activations, epsilon,
@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
auto grad_output_literal =
LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
auto input_parameter =
Parameter(&builder, 0, input_literal->shape(), "input");
auto scale_parameter =
Parameter(&builder, 1, scale_literal->shape(), "scale");
auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean");
auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance");
auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
auto grad_output_parameter =
Parameter(&builder, 4, grad_output_literal->shape(), "grad_output");
Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> var_data =
client_->TransferToServer(*var_literal).ConsumeValueOrDie();
client_->TransferToServer(var_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> grad_output_data =
client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
client_->TransferToServer(grad_output_literal).ConsumeValueOrDie();
BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
grad_output_parameter, epsilon, feature_index);
auto expected =
LiteralUtil::MakeTuple({expected_grad_activation.get(),
LiteralUtil::CreateR1<float>(grad_scale).get(),
LiteralUtil::CreateR1<float>(grad_offset).get()});
auto expected = LiteralUtil::MakeTupleFromSlices(
{expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
LiteralUtil::CreateR1<float>(grad_offset)});
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
ComputeAndCompareTuple(&builder, *expected,
ComputeAndCompareTuple(&builder, expected,
{input_data.get(), scale_data.get(), mean_data.get(),
var_data.get(), grad_output_data.get()},
ErrorSpec(0.01, 1));
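For tuple-valued expectations the change is the same idea: MakeTuple took owned Literal pointers (hence the .get() calls), while MakeTupleFromSlices takes the Literal values directly. A minimal sketch, assuming the ClientLibraryTestBase helpers used in the tests above:
// Expected tuple built from Literal values rather than raw pointers.
xla::Literal expected = xla::LiteralUtil::MakeTupleFromSlices(
    {xla::LiteralUtil::CreateR1<float>({4, 5}),
     xla::LiteralUtil::CreateR1<float>({5, 5})});
// ComputeAndCompareTuple now takes the Literal by const reference.
ComputeAndCompareTuple(&builder, expected, {}, xla::ErrorSpec(0.1));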
@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.6875f)},
{static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
{{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
.get(),
{{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
.get(),
{static_cast<bfloat16>(4), static_cast<bfloat16>(5)}),
LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
.get()});
{static_cast<bfloat16>(5), static_cast<bfloat16>(5)})});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02));
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02));
}
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
.get(),
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
.get(),
{static_cast<bfloat16>(0), static_cast<bfloat16>(0)}),
LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
.get()});
{static_cast<bfloat16>(16), static_cast<bfloat16>(20)})});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}
} // namespace
@ -60,10 +60,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
float end, int seed) {
*r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r3_array->FillRandom(start, end, seed);
auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r3_global_data =
client_->TransferToServer(*r3_data).ConsumeValueOrDie();
client_->TransferToServer(r3_data).ConsumeValueOrDie();
return r3_global_data;
}
@ -74,10 +74,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
float end, int seed) {
*r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r2_array->FillRandom(start, end, seed);
auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r2_global_data =
client_->TransferToServer(*r2_data).ConsumeValueOrDie();
client_->TransferToServer(r2_data).ConsumeValueOrDie();
return r2_global_data;
}
@ -293,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}}),
ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
@ -301,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
{{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R3ImplicitBroadcastSpec {
@ -370,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
}
auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
ComputeAndCompareLiteral(
&builder, *expected,
{r3_implicit_global_data.get(), r3_global_data.get()},
&builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()},
ErrorSpec(1e-7, 1e-7));
}
@ -395,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()},
ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}}));
auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}}));
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}}));
auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}, {2}}}));
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
XlaBuilder b(TestName());
auto r1 =
ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
XlaBuilder b(TestName());
auto r1 =
ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
&b, LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}}));
auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}}}));
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R2ImplicitBroadcastSpec {
@ -618,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
ComputeAndCompareLiteral(
&builder, *expected,
&builder, expected,
{r2_implicit_global_data1.get(), r2_global_data.get(),
r2_implicit_global_data2.get()},
ErrorSpec(1e-6, 1e-6));
@ -630,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}}));
auto r2 =
ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}}));
auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}}));
auto r2 =
ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1}, {2}}));
auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1, {0});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {1});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {2});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
@ -697,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
auto r1_1 = ConstantR1<float>(&b, {100, 200});
auto r1_2 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
for (int i = 0; i < 3; ++i) {
r3 = Add(r1_0, r3, {0});
r3 = Add(r3, r1_1, {1});
@ -709,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
{{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
{{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
@ -730,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
{{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
{{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
@ -739,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
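The broadcast tests exercise two more spots touched by the cleanup: Relayout is now a member call on a Literal value that returns a new Literal, and ConstantLiteral takes the literal by const reference. A small sketch only, assuming client_, an XlaBuilder b, and illustrative values:
// Relayout on a value; no '->' and no extra dereference when transferring.
xla::Literal r2 = xla::LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}})
                      .Relayout(xla::LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<xla::GlobalData> r2_data =
    client_->TransferToServer(r2).ConsumeValueOrDie();
xla::XlaOp c = ConstantLiteral(&b, r2);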
@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0<float>(42.0),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0<float>(42.0), result,
error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
@ -63,7 +63,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), result,
error_spec_));
}
@ -86,12 +86,12 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
LiteralSlice(*result, {0}), error_spec_));
LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
LiteralSlice(result, {0}), error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
LiteralSlice(*result, {1}), error_spec_));
LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
LiteralSlice(result, {1}), error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
@ -107,7 +107,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), result,
error_spec_));
}
@ -126,7 +126,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), result,
error_spec_));
}
@ -143,9 +143,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
{{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
*result, error_spec_));
LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
{{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
@ -166,9 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
Array2D<float> pz({{1, 2}, {1, 2}});
expected.FillWithPZ(pz);
EXPECT_TRUE(
LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
@ -197,9 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
}
expected.FillWithYX(yx);
EXPECT_TRUE(
LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
@ -220,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array),
result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
@ -240,9 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
Array4D<float> expected(64, 64, 3, 3);
expected.Fill(1.0f);
EXPECT_TRUE(
LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
@ -263,9 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
Array4D<float> expected(3, 3, 2, 2);
expected.FillWithYX(to_broadcast);
EXPECT_TRUE(
LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
@ -295,9 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(
LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
} // namespace
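On the verification side, LiteralTestUtil::Near and LiteralSlice now take Literal values or references rather than pointers. Roughly, with result and error_spec_ as in the tests above (an illustrative fragment, not code from this commit):
// ExecuteAndTransfer returns a Literal value, compared directly...
EXPECT_TRUE(xla::LiteralTestUtil::Near(
    xla::LiteralUtil::CreateR0<float>(42.0), result, error_spec_));
// ...and tuple elements are still viewed without copying via LiteralSlice.
EXPECT_TRUE(xla::LiteralTestUtil::Near(
    xla::LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
    xla::LiteralSlice(result, {0}), error_spec_));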
@ -77,8 +77,7 @@ class CallOpTest : public ClientLibraryTestBase {
XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32IdentityComputation();
auto constant =
ConstantLiteral(&builder, *LiteralUtil::CreateR0<float>(42.0));
auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0<float>(42.0));
Call(&builder, callee, {constant});
ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S0F32AdditionComputation();
auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
@ -98,9 +97,9 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S2F32AdditionComputation();
auto x =
ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
auto y =
ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
@ -133,7 +132,7 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> start,
client_->TransferToServer(*LiteralUtil::CreateR0<float>(1.0f)));
client_->TransferToServer(LiteralUtil::CreateR0<float>(1.0f)));
ComputeAndCompareR0<float>(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f));
}
@ -141,10 +140,10 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32TupleComputation();
auto elem = LiteralUtil::CreateR0<float>(42.0);
auto tuple = LiteralUtil::MakeTuple({elem.get()});
Call(&builder, callee, {ConstantLiteral(&builder, *elem)});
auto tuple = LiteralUtil::MakeTuple({&elem});
Call(&builder, callee, {ConstantLiteral(&builder, elem)});
ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f));
ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f));
}
} // namespace
@ -38,14 +38,14 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
|
||||
XlaBuilder builder("add_two_params");
|
||||
auto param_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f});
|
||||
|
||||
auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0");
|
||||
auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1");
|
||||
auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0");
|
||||
auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1");
|
||||
Add(p0, p1);
|
||||
|
||||
auto param0_data =
|
||||
client_->TransferToServer(*param_literal).ConsumeValueOrDie();
|
||||
client_->TransferToServer(param_literal).ConsumeValueOrDie();
|
||||
auto param1_data =
|
||||
client_->TransferToServer(*param_literal).ConsumeValueOrDie();
|
||||
client_->TransferToServer(param_literal).ConsumeValueOrDie();
|
||||
|
||||
auto computation_status = builder.Build();
|
||||
ASSERT_IS_OK(computation_status.status());
|
||||
@ -86,12 +86,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
|
||||
auto computation = computation_status.ConsumeValueOrDie();
|
||||
|
||||
auto f32_literal = LiteralUtil::CreateR0<float>(1.1f);
|
||||
auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie();
|
||||
auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie();
|
||||
auto f32_4_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
|
||||
auto f32_4_data =
|
||||
client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie();
|
||||
client_->TransferToServer(f32_4_literal).ConsumeValueOrDie();
|
||||
auto u8_4_literal = LiteralUtil::CreateR1U8("hola");
|
||||
auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie();
|
||||
auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie();
|
||||
|
||||
// Match
|
||||
auto status = client_->Execute(
|
@ -101,7 +101,7 @@ StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
return client_->Execute(computation, arguments, &execution_options_);
}

StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
@ -113,7 +113,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
&execution_options);
}

StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
// Build the computation, as a convenience.
@ -121,8 +121,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
}

StatusOr<std::unique_ptr<Literal>>
ClientLibraryTestBase::ExecuteAndTransferReference(
StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
@ -148,15 +147,15 @@ string ClientLibraryTestBase::ExecuteToString(
if (!result.ok()) {
return result.status().ToString();
} else {
return result.ValueOrDie()->ToString();
return result.ValueOrDie().ToString();
}
}

void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
Literal expected_literal = LiteralUtil::CreateR1(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}

@ -182,7 +181,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
const string& error_message)>& verify_output) {
// Try with no layout requirement.
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
verify_output(*actual, "");
verify_output(actual, "");

// Try with all output layouts.
std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape()));
@ -193,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
AsInt64Slice(expected.shape().dimensions()), minor_to_major);
TF_ASSIGN_OR_RETURN(auto actual,
ExecuteAndTransfer(computation, arguments, &layout));
verify_output(*actual,
verify_output(actual,
absl::StrCat("Test with output layout: ",
ShapeUtil::HumanStringWithLayout(layout)));
} while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
@ -218,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
TF_ASSIGN_OR_RETURN(auto literal,
client_->Transfer(*arguments[index], nullptr));
// Skip tuples because they don't have a rank.
if (ShapeUtil::IsTuple(literal->shape())) {
if (ShapeUtil::IsTuple(literal.shape())) {
layout_strings.push_back(
ShapeUtil::HumanStringWithLayout(literal->shape()));
ShapeUtil::HumanStringWithLayout(literal.shape()));
arguments_with_layout.push_back(arguments[index]);
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@ -228,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
return Status::OK();
}

std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape()));
std::vector<int64> minor_to_major(ShapeUtil::Rank(literal.shape()));
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
do {
auto literal_relayout =
literal->Relayout(LayoutUtil::MakeLayout(minor_to_major));
literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
layout_strings.push_back(
ShapeUtil::HumanStringWithLayout(literal_relayout->shape()));
ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
TF_ASSIGN_OR_RETURN(auto data,
client_->TransferToServer(*literal_relayout));
client_->TransferToServer(literal_relayout));
arguments_with_layout.push_back(data.get());
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@ -256,7 +255,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
for (const auto& str : layout_strings) {
absl::StrAppend(&error_message, str, " ");
}
verify_output(*actual, error_message);
verify_output(actual, error_message);
return Status::OK();
};

@ -290,11 +289,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
std::unique_ptr<Literal> converted_expected;
Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@ -319,7 +318,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
return Status::OK();
}

@ -346,11 +345,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
std::unique_ptr<Literal> converted_expected;
Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@ -376,7 +375,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
return Status::OK();
}

@ -391,12 +390,12 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
auto actual = actual_status.ConsumeValueOrDie();

// Turn the expected value into a literal.
std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
Literal expected_literal = LiteralUtil::CreateR1U8(expected);

VLOG(1) << "expected: " << expected_literal->ToString();
VLOG(1) << "actual: " << actual->ToString();
VLOG(1) << "expected: " << expected_literal.ToString();
VLOG(1) << "actual: " << actual.ToString();

EXPECT_EQ(expected, actual->GetR1U8AsString());
EXPECT_EQ(expected, actual.GetR1U8AsString());
}

void ClientLibraryTestBase::ComputeAndCompareTuple(
@ -408,7 +407,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}

void ClientLibraryTestBase::ComputeAndCompareTuple(
@ -420,7 +419,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
}

void ClientLibraryTestBase::ComputeAndCompare(
@ -430,9 +429,9 @@ void ClientLibraryTestBase::ComputeAndCompare(
if (!status_or_data.ok()) {
return;
}
std::unique_ptr<Literal> reference, result;
Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
}

void ClientLibraryTestBase::ComputeAndCompare(
@ -442,12 +441,12 @@ void ClientLibraryTestBase::ComputeAndCompare(
if (!status_or_data.ok()) {
return;
}
std::unique_ptr<Literal> reference, result;
Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
}

StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
StatusOr<std::pair<Literal, Literal>>
ClientLibraryTestBase::ComputeValueAndReference(
XlaBuilder* builder, absl::Span<const Literal> arguments) {
// Transfer the arguments to the executor service. We put the unique_ptr's
@ -569,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
return ConstantLiteral(builder, use_bfloat16_
? *LiteralUtil::ConvertF32ToBF16(literal)
: literal);
? LiteralUtil::ConvertF32ToBF16(literal)
: LiteralSlice(literal));
}

std::unique_ptr<GlobalData>
@ -600,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
const Literal& literal) {
if (use_bfloat16_) {
return std::move(*LiteralUtil::ConvertF32ToBF16(literal));
return LiteralUtil::ConvertF32ToBF16(literal);
}
return literal.Clone();
}

@ -95,11 +95,11 @@ class ClientLibraryTestBase : public ::testing::Test {
StatusOr<std::unique_ptr<GlobalData>> Execute(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments);

StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
StatusOr<Literal> ExecuteAndTransfer(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);

StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
StatusOr<Literal> ExecuteAndTransfer(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
@ -107,7 +107,7 @@ class ClientLibraryTestBase : public ::testing::Test {
// This executes the computation via the reference client (which connects a
// interpreter backend). The result is used as the expected values of the
// computation.
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransferReference(
StatusOr<Literal> ExecuteAndTransferReference(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
@ -282,7 +282,7 @@ class ClientLibraryTestBase : public ::testing::Test {

template <class T>
XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
return AddParam(*LiteralUtil::CreateFromArray(argument), builder);
return AddParam(LiteralUtil::CreateFromArray(argument), builder);
}

// Creates a constant instruction with the given literal. When the
@ -297,14 +297,14 @@ class ClientLibraryTestBase : public ::testing::Test {
template <typename NativeT>
XlaOp CreateConstantFromArray(const Array<NativeT>& array,
XlaBuilder* builder) {
return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array),
return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
builder);
}

// Same as CreateConstantFromArray, but for scalars.
template <typename NativeT>
XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
return CreateConstantFromLiteral(*LiteralUtil::CreateR0<NativeT>(value),
return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
builder);
}

@ -375,9 +375,8 @@ class ClientLibraryTestBase : public ::testing::Test {
// Executes the computation and calculates the expected reference value using
// the reference client. Returns two literals in the order of (expected,
// actual).
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
ComputeValueAndReference(XlaBuilder* builder,
absl::Span<const Literal> arguments);
StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
XlaBuilder* builder, absl::Span<const Literal> arguments);

Client* client_;
Client* ref_client_; // To compute reference result.
@ -412,9 +411,8 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
XlaBuilder* builder, NativeT expected,
absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}

@ -428,9 +426,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}

@ -438,9 +435,8 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, absl::Span<const NativeT> expected,
absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}

@ -454,9 +450,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}

@ -464,9 +459,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
XlaBuilder* builder, const Array2D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}

@ -480,9 +475,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}

@ -490,9 +485,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
XlaBuilder* builder, const Array3D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}

@ -506,9 +501,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}

@ -516,9 +511,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
XlaBuilder* builder, const Array4D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}

@ -532,9 +527,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}

@ -542,13 +537,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = LiteralUtil::CreateR0(value);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(*literal);
Literal literal = LiteralUtil::CreateR0(value);
if (use_bfloat16_ && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal->shape(), name);
client_->TransferToServer(literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}

@ -556,13 +551,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
absl::Span<const NativeT> values, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(*literal);
Literal literal = LiteralUtil::CreateR1(values);
if (use_bfloat16_ && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal->shape(), name);
client_->TransferToServer(literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}

@ -570,13 +565,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(*literal);
Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
if (use_bfloat16_ && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal->shape(), name);
client_->TransferToServer(literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}

@ -584,13 +579,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(array_3d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(*literal);
Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
if (use_bfloat16_ && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal->shape(), name);
client_->TransferToServer(literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}

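Not part of the diff: a sketch of the parameter-helper pattern after the cleanup, mirroring the CreateR*Parameter changes above. xla::Literal is now a movable value type, so a helper can keep it on the stack and pass it straight to TransferToServer (assumed to take a const LiteralSlice&). The helper name and its fixed parameter number are hypothetical.

// Sketch only, under the assumptions stated above.
#include <memory>
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"

std::unique_ptr<xla::GlobalData> MakeScalarParameter(
    xla::Client* client, xla::XlaBuilder* builder, float value,
    xla::XlaOp* handle) {
  xla::Literal literal = xla::LiteralUtil::CreateR0<float>(value);
  // No '*literal' dereference anymore; the literal binds directly.
  auto data = client->TransferToServer(literal).ConsumeValueOrDie();
  *handle = xla::Parameter(builder, /*parameter_number=*/0, literal.shape(),
                           "param0");
  return data;
}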
@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
std::unique_ptr<GlobalData> data,
client_->Execute(computation, {}, &execution_options));

std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));

TF_ASSERT_OK_AND_ASSIGN(
auto computed, client_->Transfer(*data, &expected_literal->shape()));
auto computed, client_->Transfer(*data, &expected_literal.shape()));

ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
expected_literal->shape(), computed->shape()));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
expected_literal.shape(), computed.shape()));
EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}
@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
auto result,
client_->ExecuteAndTransfer(computation, {}, &execution_options));
LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
LiteralSlice(*result, {0}));
LiteralSlice(result, {0}));
LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
LiteralSlice(*result, {1}));
LiteralSlice(result, {1}));

EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
EXPECT_TRUE(ShapeUtil::IsTuple(result.shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape()));

EXPECT_TRUE(ShapeUtil::Equal(
ShapeUtil::GetTupleElementShape(result->shape(), 0),
ShapeUtil::GetTupleElementShape(result.shape(), 0),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{0, 1})));
EXPECT_TRUE(ShapeUtil::Equal(
ShapeUtil::GetTupleElementShape(result->shape(), 1),
ShapeUtil::GetTupleElementShape(result.shape(), 1),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{1, 0})));
}
@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> const_arg,
client_->TransferToServer(
*LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));

XlaBuilder b(TestName() + ".add");
Add(Parameter(&b, 0, shape, "param_0"),
@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {

TF_ASSERT_OK_AND_ASSIGN(
auto result_literal,
client_->Transfer(*results[0], &expected_result->shape()));
client_->Transfer(*results[0], &expected_result.shape()));

EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal));
}

} // namespace

@ -42,14 +42,14 @@ class CompilationCacheTest : public ClientLibraryTestBase {
absl::Span<GlobalData* const> arguments,
float expected_result, bool expect_cache_hit) {
ExecutionProfile execution_profile;
std::unique_ptr<Literal> result =
Literal result =
client_
->ExecuteAndTransfer(computation, arguments,
/*execution_options=*/&execution_options_,
&execution_profile)
.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR0<float>(expected_result), *result, error_spec_));
LiteralUtil::CreateR0<float>(expected_result), result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}

@ -63,10 +63,9 @@ class CompilationCacheTest : public ClientLibraryTestBase {
->Execute(computation, arguments,
&execution_options_, &execution_profile)
.ConsumeValueOrDie();
std::unique_ptr<Literal> result =
client_->Transfer(*data_handle).ConsumeValueOrDie();
Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR2<float>(expected_result), *result, error_spec_));
LiteralUtil::CreateR2<float>(expected_result), result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}

@ -88,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) {
XLA_TEST_F(CompilationCacheTest,
DISABLED_ComputationCalledWithDifferentParameters) {
std::unique_ptr<GlobalData> data_42 =
client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
client_->TransferToServer(LiteralUtil::CreateR0<float>(42.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_123 =
client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
client_->TransferToServer(LiteralUtil::CreateR0<float>(123.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_456 =
client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
client_->TransferToServer(LiteralUtil::CreateR0<float>(456.0f))
.ConsumeValueOrDie();

XlaBuilder builder(TestName());
@ -145,12 +144,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) {
auto rowmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
auto rowmaj_handle =
client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
client_->TransferToServer(rowmaj_array).ConsumeValueOrDie();

auto colmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
auto colmaj_handle =
client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
client_->TransferToServer(colmaj_array).ConsumeValueOrDie();

XlaBuilder builder(TestName());
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");

@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test {
LOG(FATAL) << "invalid client_type value";
}

StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
Client* client, const XlaOp& operand, XlaBuilder* builder,
Layout* output_layout = nullptr) {
StatusOr<Literal> ComputeConstantLiteral(Client* client, const XlaOp& operand,
XlaBuilder* builder,
Layout* output_layout = nullptr) {
TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
TF_ASSIGN_OR_RETURN(auto computed,
client->ComputeConstant(subgraph, output_layout));
@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test {
XlaBuilder* builder) {
TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
builder, nullptr));
return literal->Get<Scalar>({});
return literal.Get<Scalar>({});
}

bool IsConstant(const XlaOp& operand, XlaBuilder* builder) {
@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {

TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR1<int32>({4, 6});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
Literal expected_literal = LiteralUtil::CreateR1<int32>({4, 6});
EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}

@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) {

TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
Literal expected_literal = LiteralUtil::CreateR0<int32>(5);
EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}

@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
ConstantR2<int32>(&b, {{10, 20}, {30, 40}})),
&b, &layout_proto));

std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
expected_literal->shape(), computed->shape()));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
expected_literal.shape(), computed.shape()));
EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}

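Not part of the diff: a sketch of how a caller consumes the new StatusOr<Literal> return type shown above, instead of the old StatusOr<std::unique_ptr<Literal>>. The helper name is hypothetical; the StatusOr and Literal APIs are assumed to match the post-cleanup XLA headers.

// Sketch only, under the assumptions stated above.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"

float ReadScalarOrDie(xla::StatusOr<xla::Literal> computed) {
  // The literal is moved out of the StatusOr; member access uses '.' where
  // the pre-cleanup code used '->' on the unique_ptr.
  xla::Literal literal = computed.ConsumeValueOrDie();
  return literal.Get<float>({});
}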
@ -536,8 +536,8 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
auto x_literal = LiteralUtil::CreateR0<float>(2.f);
auto y_literal = LiteralUtil::CreateR0<float>(3.f);
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();

XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, f32_scalar, "x");
@ -559,12 +559,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();

XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, x_literal->shape(), "x");
auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto z = Parameter(&builder, 2, f32_scalar, "z");
auto bcast = Broadcast(y, {5});
@ -587,12 +587,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();

XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, x_literal->shape(), "x");
auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto z = Parameter(&builder, 2, f32_scalar, "y");
auto y_bcast = Broadcast(y, {1, 5, 7});

@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {

ComputeAndCompareTuple(
&builder,
*LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
LiteralUtil::CreateR0<float>(25.0f).get()}),
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
LiteralUtil::CreateR0<float>(25.0f)}),
{pred_arg.get()}, error_spec_);
}

@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
CreateR1TupleFloorComputation());

ComputeAndCompareTuple(
&builder,
*LiteralUtil::MakeTuple(
{LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
{pred_arg.get()}, error_spec_);
ComputeAndCompareTuple(&builder,
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
{pred_arg.get()}, error_spec_);
}

// Test true and false computations that return a tuple of a predicate, a
@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
false_builder_result.ConsumeValueOrDie());

ComputeAndCompareTuple(
&builder,
*LiteralUtil::MakeTuple(
{LiteralUtil::CreateR0<bool>(true).get(),
LiteralUtil::CreateR0<float>(12.2f).get(),
LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
{pred_arg.get()}, error_spec_);
ComputeAndCompareTuple(&builder,
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<bool>(true),
LiteralUtil::CreateR0<float>(12.2f),
LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
{pred_arg.get()}, error_spec_);
}

// Test true and false computations that return a nested tuple.
@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {

ComputeAndCompareTuple(
&builder,
*LiteralUtil::MakeTuple(
{LiteralUtil::MakeTuple(
{LiteralUtil::CreateR0<float>(46.6f).get(),
LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()})
.get(),
LiteralUtil::MakeTuple(
{LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
LiteralUtil::CreateR0<float>(9.3f).get()})
.get()}),
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<float>(46.6f),
LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
LiteralUtil::CreateR0<float>(9.3f)})}),
{pred_arg.get()}, error_spec_);
}

@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {

ComputeAndCompareTuple(
&builder,
*LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
LiteralUtil::CreateR0<float>(b).get()}),
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
{x_arg.get(), y_arg.get()}, error_spec_);
};

@ -669,10 +665,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
{
// Pred is true case.
std::vector<Literal> args;
args.push_back(std::move(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
LiteralUtil::CreateR0<int32>(-42).get()})));
args.push_back(std::move(*LiteralUtil::CreateR0<bool>(true)));
args.push_back(
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
LiteralUtil::CreateR0<int32>(-42)}));
args.push_back(LiteralUtil::CreateR0<bool>(true));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
@ -682,10 +678,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
{
// Pred is false case.
std::vector<Literal> args;
args.push_back(std::move(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
LiteralUtil::CreateR0<int32>(-42).get()})));
args.push_back(std::move(*LiteralUtil::CreateR0<bool>(false)));
args.push_back(
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
LiteralUtil::CreateR0<int32>(-42)}));
args.push_back(LiteralUtil::CreateR0<bool>(false));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");

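Not part of the diff: a sketch of the tuple-construction pattern the conditional tests above were migrated to. MakeTupleFromSlices is assumed to take the element literals directly (by slice) and return an owned xla::Literal, replacing the old pointer-based MakeTuple({elem0.get(), elem1.get()}). The function name is a hypothetical example.

// Sketch only, under the assumptions stated above.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"

xla::Literal MakeExampleTuple() {
  // The element literals are temporaries; no .get() and no std::move of a
  // dereferenced unique_ptr is needed anymore.
  return xla::LiteralUtil::MakeTupleFromSlices(
      {xla::LiteralUtil::CreateR0<float>(12.0f),
       xla::LiteralUtil::CreateR1<float>({13.0f, 16.0f})});
}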
@ -110,7 +110,7 @@ TEST_F(ConstantsTest, Small_2x2) {

TEST_F(ConstantsTest, Empty_3x0x2) {
XlaBuilder builder(TestName());
ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(
ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(
Array3D<float>(3, 0, 2)));

ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) {
{{5.f, 6.f}, // y0
{7.f, 8.f}}, // y1
});
ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(array3d));
ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(array3d));

ComputeAndCompareR3<float>(&builder, array3d, {});
}
@ -140,12 +140,11 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
{5.0f, 4.4f}, // p2
});
input_array.FillWithPZ(pz);
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR4FromArray4D(input_array);
Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array);

{
XlaBuilder builder(TestName());
ConstantLiteral(&builder, *input_literal);
ConstantLiteral(&builder, input_literal);
ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
}

@ -159,23 +158,21 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
// TODO(b/29263943): Support tuple constants.
TEST_F(ConstantsTest, DISABLED_TupleConstant) {
XlaBuilder builder(TestName());
ConstantLiteral(&builder,
*LiteralUtil::MakeTuple(
{LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
LiteralUtil::CreateR1<float>({2.0, 42}).get()}));
ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
LiteralUtil::CreateR1<float>({2.0, 42})}));

std::unique_ptr<Literal> result =
ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();

LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
LiteralSlice(*result, {0}), error_spec_);
LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(*result, {1}),
LiteralSlice(result, {0}), error_spec_);
LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(result, {1}),
error_spec_);
}

TEST_F(ConstantsTest, Token) {
XlaBuilder builder(TestName());
ConstantLiteral(&builder, *LiteralUtil::CreateToken());
ConstantLiteral(&builder, LiteralUtil::CreateToken());
// TODO(b/80000000): tokens cannot be returned from computations.
Tuple(&builder, {});
TF_ASSERT_OK(Execute(&builder, {}).status());

@ -210,10 +210,10 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
static_cast<int64>(0x8000008000000000LL),
static_cast<int64>(0x8000010000000000LL),
};
std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int64>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
Literal arg_literal = LiteralUtil::CreateR1<int64>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
client_->TransferToServer(arg_literal).ConsumeValueOrDie();

ConvertElementType(arg_param, F32);

@ -229,10 +229,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff,
0x80000000, 0x80000001, 0x80000002, 0x80000003,
0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
client_->TransferToServer(arg_literal).ConsumeValueOrDie();

ConvertElementType(arg_param, F32);

@ -247,10 +247,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XlaBuilder builder(TestName());
std::vector<float> arg{0.0f, 1.0f, 16777216.0f,
16777218.0f, 2147483647.0f, 4294967040.0f};
std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
client_->TransferToServer(arg_literal).ConsumeValueOrDie();

ConvertElementType(arg_param, U32);

@ -264,10 +264,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
client_->TransferToServer(arg_literal).ConsumeValueOrDie();

ConvertElementType(arg_param, S64);

@ -281,10 +281,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
Literal arg_literal = LiteralUtil::CreateR1<int32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
client_->TransferToServer(arg_literal).ConsumeValueOrDie();

ConvertElementType(arg_param, S64);

@ -318,10 +318,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
9223370937343148032.f,
-9223371487098961920.f,
-9223370937343148032.f};
std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
client_->TransferToServer(arg_literal).ConsumeValueOrDie();

ConvertElementType(arg_param, S64);

@ -456,7 +456,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
client_->TransferToServer(*LiteralUtil::CreateR1<half>(input)));
client_->TransferToServer(LiteralUtil::CreateR1<half>(input)));

XlaBuilder builder(TestName());
ConvertElementType(
@ -476,7 +476,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
client_->TransferToServer(*LiteralUtil::CreateR1<float>(input)));

client_->TransferToServer(LiteralUtil::CreateR1<float>(input)));

XlaBuilder builder(TestName());
ConvertElementType(

@ -93,8 +93,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest,
auto weight_array = absl::make_unique<Array4D<float>>(4, 3, 1, 1);
weight_array->FillWithMultiples(0.2);
auto weight_data =
client_
->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array))
client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array))
.ConsumeValueOrDie();

XlaBuilder builder(TestName());

@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
}));

ComputeAndCompare(&builder,
{std::move(*LiteralUtil::CreateFromArray(input_data)),
std::move(*LiteralUtil::CreateFromArray(filter_data))},
{LiteralUtil::CreateFromArray(input_data),
LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
{7.0f, 8.0f},
}));
ComputeAndCompare(&builder,
{std::move(*LiteralUtil::CreateFromArray(input_data)),
std::move(*LiteralUtil::CreateFromArray(filter_data))},
{LiteralUtil::CreateFromArray(input_data),
LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
}));

ComputeAndCompare(&builder,
{std::move(*LiteralUtil::CreateFromArray(input_data)),
std::move(*LiteralUtil::CreateFromArray(filter_data))},
{LiteralUtil::CreateFromArray(input_data),
LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
{{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
// clang-format on
ComputeAndCompare(&builder,
{std::move(*LiteralUtil::CreateFromArray(input_data)),
std::move(*LiteralUtil::CreateFromArray(filter_data))},
{LiteralUtil::CreateFromArray(input_data),
LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
Array3D<float> expected({{{510, 610, 710, 810}}});

auto input_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();

ComputeAndCompareR3<float>(&builder, expected,
@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});

auto input_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();

ComputeAndCompareR3<T>(&builder, expected,
@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});

auto input_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();

ComputeAndCompareR3<float>(&builder, expected,
@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});

auto input_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();

ComputeAndCompareR3<float>(&builder, expected,
@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});

auto input_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();

ComputeAndCompareR3<T>(&builder, expected,
@ -435,23 +435,23 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
iota(input_elems.begin(), input_elems.end(), 1.0f);
auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie();

std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota(filter_elems.begin(), filter_elems.end(), 1.0f);
auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();

auto expected_r1 = LiteralUtil::CreateR1<float>(
{19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();

auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie();
auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*filter_r5).ConsumeValueOrDie();
client_->TransferToServer(filter_r5).ConsumeValueOrDie();

ComputeAndCompareLiteral(&builder, *expected_r5,
ComputeAndCompareLiteral(&builder, expected_r5,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@ -498,23 +498,23 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();

std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();

auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie();

auto input_literal =
client_->TransferToServer(*input_r4).ConsumeValueOrDie();
client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
client_->TransferToServer(filter_r4).ConsumeValueOrDie();

ComputeAndCompareLiteral(&builder, *expected_r4,
ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@ -558,12 +558,12 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();

std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();

auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
@ -571,14 +571,14 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
|
||||
auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
|
||||
auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*input_r4).ConsumeValueOrDie();
|
||||
client_->TransferToServer(input_r4).ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
|
||||
client_->TransferToServer(filter_r4).ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareLiteral(&builder, *expected_r4,
|
||||
ComputeAndCompareLiteral(&builder, expected_r4,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
@ -624,26 +624,26 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest {
|
||||
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
|
||||
iota_int_init_value(input_elems, 1);
|
||||
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
|
||||
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
|
||||
auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
|
||||
iota_int_init_value(filter_elems, 1);
|
||||
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
|
||||
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
|
||||
auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
|
||||
|
||||
auto expected_r1 = LiteralUtil::CreateR1<T>(
|
||||
{static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
|
||||
static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
|
||||
static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
|
||||
static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
|
||||
auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
|
||||
auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*input_r4).ConsumeValueOrDie();
|
||||
client_->TransferToServer(input_r4).ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
|
||||
client_->TransferToServer(filter_r4).ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareLiteral(&builder, *expected_r4,
|
||||
ComputeAndCompareLiteral(&builder, expected_r4,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
@ -692,8 +692,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
|
||||
expected_result.Fill(0);
|
||||
|
||||
ComputeAndCompare(&builder,
|
||||
{std::move(*LiteralUtil::CreateFromArray(param0)),
|
||||
std::move(*LiteralUtil::CreateFromArray(param1))},
|
||||
{LiteralUtil::CreateFromArray(param0),
|
||||
LiteralUtil::CreateFromArray(param1)},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
@ -749,26 +749,25 @@ class Convolve1D1WindowTestBase
|
||||
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
|
||||
static_cast<T>(1.0f));
|
||||
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
|
||||
auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
|
||||
auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
|
||||
static_cast<T>(1.0f));
|
||||
|
||||
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
|
||||
auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
|
||||
auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<T> expect_elems(batch * output_feature * num_windows,
|
||||
static_cast<T>(window_size * input_feature));
|
||||
auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
|
||||
auto expected_r3 =
|
||||
expected_r1->Reshape({batch, num_windows, output_feature})
|
||||
.ConsumeValueOrDie();
|
||||
auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*input_r3).ConsumeValueOrDie();
|
||||
client_->TransferToServer(input_r3).ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
|
||||
ComputeAndCompareLiteral(&builder, *expected_r3,
|
||||
client_->TransferToServer(filter_r3).ConsumeValueOrDie();
|
||||
ComputeAndCompareLiteral(&builder, expected_r3,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
@ -868,8 +867,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
|
||||
}));
|
||||
|
||||
ComputeAndCompare(&builder,
|
||||
{std::move(*LiteralUtil::CreateFromArray(input_data)),
|
||||
std::move(*LiteralUtil::CreateFromArray(filter_data))},
|
||||
{LiteralUtil::CreateFromArray(input_data),
|
||||
LiteralUtil::CreateFromArray(filter_data)},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
@ -891,9 +890,8 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
|
||||
Array4D<float> filter_data(1, 1, 1, 2);
|
||||
filter_data.FillIota(10);
|
||||
|
||||
ComputeAndCompare(&builder,
|
||||
{std::move(*LiteralUtil::CreateFromArray(input_data)),
|
||||
std::move(*LiteralUtil::CreateFromArray(filter_data))});
|
||||
ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data),
|
||||
LiteralUtil::CreateFromArray(filter_data)});
|
||||
}
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
|
||||
@ -928,8 +926,7 @@ XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
|
||||
/*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums,
|
||||
/*feature_group_count=*/64);
|
||||
|
||||
ComputeAndCompare(&builder,
|
||||
{std::move(*LiteralUtil::CreateFromArray(input_data))},
|
||||
ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
|
@@ -1335,23 +1335,23 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {

auto gradients_flat = LiteralUtil::CreateR1<float>({1});
auto gradients_literal =
gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
auto gradients = ConstantLiteral(&builder, *gradients_literal);
gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
auto gradients = ConstantLiteral(&builder, gradients_literal);

auto weights_flat = LiteralUtil::CreateR1<float>({1, 10, 100});
auto weights_literal =
weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto weights = ConstantLiteral(&builder, *weights_literal);
weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto weights = ConstantLiteral(&builder, weights_literal);

auto expected_flat = LiteralUtil::CreateR1<float>({10});
auto expected_literal =
expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();

auto mirrored_weights = Rev(weights, {2, 3, 4});
ConvWithGeneralPadding(gradients, mirrored_weights,
/*window_strides=*/{1, 1, 1},
/*padding=*/{{0, 0}, {0, 0}, {1, 1}});
ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
}

XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
@@ -1359,17 +1359,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {

auto activations_flat = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
auto activations_literal =
activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
auto activations = ConstantLiteral(&builder, *activations_literal);
activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
auto activations = ConstantLiteral(&builder, activations_literal);

auto gradients_flat = LiteralUtil::CreateR1<float>({100, 10, 1});
auto gradients_literal =
gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto gradients = ConstantLiteral(&builder, *gradients_literal);
gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto gradients = ConstantLiteral(&builder, gradients_literal);

auto expected_flat = LiteralUtil::CreateR1<float>({13, 24, 130});
auto expected_literal =
expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();

auto forward_conv =
ConvGeneralDilated(activations, gradients,
@@ -1379,7 +1379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
XlaBuilder::CreateDefaultConvDimensionNumbers(
/*num_spatial_dims=*/3));
Transpose(forward_conv, {0, 1, 2, 3, 4});
ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
}

} // namespace
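The convolution-test hunks above all apply the same calling-convention change: LiteralUtil factories now return xla::Literal by value instead of std::unique_ptr<Literal>, Literal::Reshape yields a StatusOr<Literal>, and Client::TransferToServer takes the literal directly rather than through a dereferenced unique_ptr. A minimal sketch of the new pattern, not taken from the commit itself; the function name and shape values are placeholders, and `client` stands in for the client_ member that ClientLibraryTestBase provides:

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/literal_util.h"

void TransferReshapedLiteral(xla::Client* client) {
  // The factory returns a Literal value, so members are reached with '.'
  // rather than '->', and no '*' dereference is needed at the call sites.
  xla::Literal r1 = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
  xla::Literal r2 = r1.Reshape({2, 2}).ConsumeValueOrDie();
  // TransferToServer is handed the Literal itself instead of *unique_ptr.
  auto device_data = client->TransferToServer(r2).ConsumeValueOrDie();
  (void)device_data;
}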
@@ -40,16 +40,16 @@ class CopyOpTest : public HloTestBase {
protected:
void TestCopyOp(const Literal& literal) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(literal.CloneToUnique()));
auto constant =
builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone()));
builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopy, constant));
auto computation = builder.Build();
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));

std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
Literal result = ExecuteAndTransfer(std::move(module), {});
EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}

void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
@@ -58,31 +58,30 @@ class CopyOpTest : public HloTestBase {
};

XLA_TEST_F(CopyOpTest, CopyR0Bool) {
TestCopyOp(*LiteralUtil::CreateR0<bool>(true));
TestCopyOp(LiteralUtil::CreateR0<bool>(true));
}

XLA_TEST_F(CopyOpTest, CopyR1S0U32) {
TestCopyOp(*LiteralUtil::CreateR1<uint32>({}));
TestCopyOp(LiteralUtil::CreateR1<uint32>({}));
}

XLA_TEST_F(CopyOpTest, CopyR1S3U32) {
TestCopyOp(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
TestCopyOp(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}

XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
TestCopyOp(
*LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}

XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
TestCopyOp(*LiteralUtil::CreateR4(
TestCopyOp(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}

XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
}

XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
@@ -90,7 +89,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {

// Copy literal to device to use as parameter.
auto literal = LiteralUtil::CreateR0<float>(42.0);
Shape shape = literal->shape();
Shape shape = literal.shape();

auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0"));
@@ -102,9 +101,8 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));

std::unique_ptr<Literal> result =
ExecuteAndTransfer(std::move(module), {literal.get()});
LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);
Literal result = ExecuteAndTransfer(std::move(module), {&literal});
LiteralTestUtil::ExpectR0Near<float>(42.0f, result, error_spec_);
}

XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
@@ -123,19 +121,17 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {

auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, *result,
Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, result,
error_spec_);
}

XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
HloComputation::Builder builder(TestName());

std::unique_ptr<Literal> literal =
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
// Reverse the minor-to-major order of the literal.
Layout* literal_layout =
literal->mutable_shape_do_not_use()->mutable_layout();
Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout();
ASSERT_EQ(2, literal_layout->minor_to_major_size());
literal_layout->mutable_minor_to_major()->SwapElements(0, 1);

@@ -149,11 +145,11 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {

auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
Literal result = ExecuteAndTransfer(std::move(module), {});

// The result of the computation has the default layout, which is the inverse
// of the layout of the source literal.
LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, *result,
LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, result,
error_spec_);
}

@@ -169,7 +165,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {

HloComputation::Builder builder(TestName());

std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(a);
Literal literal = LiteralUtil::CreateR3FromArray3D(a);

HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -182,9 +178,9 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0}));
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
Literal result = ExecuteAndTransfer(std::move(module), {});

LiteralTestUtil::ExpectR3EqualArray3D(a, *result);
LiteralTestUtil::ExpectR3EqualArray3D(a, result);
}

void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
@@ -203,7 +199,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,

HloComputation::Builder builder(TestName());

std::unique_ptr<Literal> literal = LiteralUtil::CreateR4FromArray4D(a);
Literal literal = LiteralUtil::CreateR4FromArray4D(a);

HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -216,9 +212,9 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation));
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
Literal result = ExecuteAndTransfer(std::move(module), {});

LiteralTestUtil::ExpectR4EqualArray4D(a, *result);
LiteralTestUtil::ExpectR4EqualArray4D(a, result);
}

XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) {
@@ -250,11 +246,11 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {

XlaBuilder builder(TestName());
Parameter(&builder, 0, in_shape, "input");
auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie();
auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie();

auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual));
}

} // namespace
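The copy_test.cc hunks show the execution side of the same cleanup: ExecuteAndTransfer now returns a Literal value, and execution arguments are passed as raw Literal pointers rather than unique_ptr .get() calls. A hedged sketch of a call site after the change; this is a fragment for an HloTestBase-style test body, with the module construction elided:

// `module` is assumed to be a std::unique_ptr<HloModule> built earlier in the
// test (e.g. via CreateNewModule() and AddEntryComputation, as above).
xla::Literal literal = xla::LiteralUtil::CreateR0<float>(42.0f);
// Arguments: {&literal} instead of {literal.get()}; result: a Literal value.
xla::Literal result = ExecuteAndTransfer(std::move(module), {&literal});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(literal, result));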
@@ -46,7 +46,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
}

XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
@@ -68,9 +68,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
EXPECT_EQ(
*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
*ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()}));
EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
ExecuteAndTransfer(std::move(module), {&literal0, &literal1}));
}

// On the GPU backend, constants get special handling. Someone might pass a
@@ -95,8 +94,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
*ExecuteAndTransfer(std::move(module), {literal0.get()}));
EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
ExecuteAndTransfer(std::move(module), {&literal0}));
}

} // namespace
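Tuple literals follow the same convention in the hunks above: LiteralUtil::MakeTuple now takes pointers to existing Literal values and returns a Literal, while the dot-operation hunks further down use LiteralUtil::MakeTupleFromSlices to build a tuple directly from values. A small sketch with made-up element values:

xla::Literal elem0 = xla::LiteralUtil::CreateR1<float>({1, 2, 3});
xla::Literal elem1 = xla::LiteralUtil::CreateR1<float>({10, 20});
// Pointer inputs: the tuple literal is assembled from the pointed-to elements.
xla::Literal tuple = xla::LiteralUtil::MakeTuple({&elem0, &elem1});
// Value inputs, as used in dot_operation_test.cc below.
xla::Literal tuple2 = xla::LiteralUtil::MakeTupleFromSlices(
    {xla::LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
     xla::LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})});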
@@ -80,8 +80,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {

module->AddEntryComputation(builder.Build());

std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR0Near<float>(44.0f, *result, error_spec_);
Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
}

XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
@@ -101,8 +101,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {

module->AddEntryComputation(builder.Build());

std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR0Near<float>(10.0f, *result, error_spec_);
Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
}

XLA_TEST_F(CustomCallTest,
@@ -125,9 +125,9 @@ XLA_TEST_F(CustomCallTest,

module->AddEntryComputation(b.Build());

std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR3EqualArray3D<float>(
Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result);
Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
}

class CustomCallClientAPITest : public ClientLibraryTestBase {};
@@ -64,11 +64,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) {

// Try copying the elements back and comparing it
auto handles = result_status.ConsumeValueOrDie();
std::unique_ptr<Literal> literal;
Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
}

TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
@@ -86,19 +86,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
auto handles1 = result_status1.ConsumeValueOrDie();
auto handles2 = result_status2.ConsumeValueOrDie();

std::unique_ptr<Literal> literal;
Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1]));
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);

handles1[0].reset();
handles1[1].reset();

TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1]));
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
}

XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
@@ -116,15 +116,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
// the same as handle[3] and handle[1] should be the same as handle[2].
auto handles = result_status.ConsumeValueOrDie();

std::unique_ptr<Literal> literal;
Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
}

TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
@@ -142,19 +142,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
// should not have been deallocated because of reference counting.
global_data.reset();

std::unique_ptr<Literal> literal;
Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);

/// Try deallocating one of the repeated elements, then copy
handles[0].reset();

TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
}

TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
@@ -170,10 +170,9 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) {

XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
LiteralUtil::CreateR1<float>({3.14f, -100.25f});
Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
Tuple(&builder, {p});
auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});
@@ -68,16 +68,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
XlaOp param;
auto param_data = CreateParameterAndTransferLiteral(
0,
*LiteralUtil::MakeTuple(
{LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}).get(),
LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}),
"arg0", &builder, &param);
auto lhs = GetTupleElement(param, 0);
auto rhs = GetTupleElement(param, 1);
Dot(lhs, rhs);

ComputeAndCompareLiteral(&builder,
*LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
{param_data.get()});
}

@@ -196,11 +196,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {

auto lhs_handle =
this->client_
->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
.ConsumeValueOrDie();
auto rhs_handle = this->client_
->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
.ConsumeValueOrDie();

@@ -219,14 +219,14 @@ class SquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f}, {3.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -286,24 +286,23 @@ void ParametricDotTest::TestImpl() {

std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
std::unique_ptr<Literal> dot_lhs_lit =
LiteralUtil::CreateR2FromArray2DWithLayout(
*dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(
param.dot_lhs_row_major)));
Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*dot_lhs_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
std::unique_ptr<GlobalData> dot_lhs_handle =
client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();
client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie();

std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
Layout rhs_layout = LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
std::unique_ptr<Literal> dot_rhs_lit =
Literal dot_rhs_lit =
LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
std::unique_ptr<GlobalData> dot_rhs_handle =
client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();
client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie();

std::unique_ptr<Array2D<NativeT>> addend_data;
std::unique_ptr<Literal> addend_lit;
Literal addend_lit;
std::unique_ptr<GlobalData> addend_handle;

if (param.has_addend) {
@@ -311,7 +310,7 @@ void ParametricDotTest::TestImpl() {
addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*addend_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.addend_row_major)));
addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie();
}

XlaBuilder builder(TestName());
@@ -477,14 +476,14 @@ class NonsquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -511,12 +510,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
auto lhs_handle =
client_
->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
auto rhs_handle =
client_
->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
@@ -584,7 +583,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});

auto x_data = this->client_
->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
{{2000.0f, 200.0f}, {20.0f, 2.0f}}},
{{{3000.0f, 300.0f}, {30.0f, 3.0f}},
@@ -592,7 +591,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
.ConsumeValueOrDie();
auto y_data =
this->client_
->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{11.0f, 22.0f}, {33.0f, 44.0f}},
{{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
@@ -630,13 +629,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {

auto x_data =
this->client_
->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
.ConsumeValueOrDie();

auto y_data =
this->client_
->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
.ConsumeValueOrDie();

@@ -668,7 +667,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {

auto x_data =
this->client_
->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{9.0f, 10.0f}, {11.0f, 12.0f}},
{{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
@@ -676,7 +675,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {

auto y_data =
this->client_
->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
{{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
.ConsumeValueOrDie();
@@ -708,14 +707,14 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
auto lhs_handle =
this->client_
->TransferToServer(
*LiteralUtil::CreateR2FromArray2DWithLayout<T>(
LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*lhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
this->client_
->TransferToServer(
*LiteralUtil::CreateR2FromArray2DWithLayout<T>(
LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*rhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
@@ -778,15 +777,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
*LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
*LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
*LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));

Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
this->template ComputeAndCompareR2<T>(
@@ -827,15 +826,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
*LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
*LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
*LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));

Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
this->template ComputeAndCompareR2<T>(
@@ -124,13 +124,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
// vector<bool> is special so that it cannot be a Span<bool>, which
// is what the code below wants. So instead we do this.
Literal input_values =
std::move(*LiteralUtil::CreateR1(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
LiteralUtil::CreateR1(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie();
Literal expected_values =
std::move(*LiteralUtil::CreateR1(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR1(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());

XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -150,13 +150,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array2D<int>& expected_values_int) {
Literal input_values =
std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());

XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -176,13 +176,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array3D<int>& expected_values_int) {
Literal input_values =
std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());

XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -359,17 +359,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
void RunR0(int input_value_int, int update_value_int,
const std::vector<IndexT> slice_starts, int expected_value_int) {
Literal input_value =
std::move(*LiteralUtil::CreateR0(input_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR0(input_value_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_value =
std::move(*LiteralUtil::CreateR0(update_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR0(update_value_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_value =
std::move(*LiteralUtil::CreateR0(expected_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR0(expected_value_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());

XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -390,17 +390,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
absl::Span<const int> expected_values_int) {
Literal input_values =
std::move(*LiteralUtil::CreateR1(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR1(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
std::move(*LiteralUtil::CreateR1(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR1(update_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(*LiteralUtil::CreateR1(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR1(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());

XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -421,17 +421,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array2D<int>& expected_values_int) {
Literal input_values =
std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR2FromArray2D(update_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());

XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -452,17 +452,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array3D<int>& expected_values_int) {
Literal input_values =
std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR3FromArray3D(update_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());

XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -529,9 +529,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {

template <typename NativeT>
void DumpArray(const string& name, const Array3D<NativeT> values) {
std::unique_ptr<Literal> literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(values);
LOG(INFO) << name << ":" << literal->ToString();
Literal literal = LiteralUtil::CreateR3FromArray3D<NativeT>(values);
LOG(INFO) << name << ":" << literal.ToString();
}
};

@@ -719,7 +718,7 @@ void BM_DynamicSlice(int num_iters) {
auto input_literal = LiteralUtil::CreateR4(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
auto input = ConstantLiteral(&builder, *input_literal);
auto input = ConstantLiteral(&builder, input_literal);

// Create dynamic slice start indices as a parameter: shape [4]
auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
@@ -740,7 +739,7 @@ void BM_DynamicSlice(int num_iters) {
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
stream.get(), *start_indices_literal, buffer));
stream.get(), start_indices_literal, buffer));

std::unique_ptr<LocalExecutable> executable =
client
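The dynamic_ops_test.cc hunks above show how type-converted literals simplify: the old code had to dereference the unique_ptr and wrap the converted result in std::move(*...), while the value API chains Convert directly off the factory call. A minimal sketch of the new chain; the int element values and the float target type are placeholders, and on the usual 32-bit-int platforms this converts an S32 literal to F32:

// ValueOrDie unwraps the StatusOr<Literal> returned by Convert.
std::vector<int> host_values = {1, 2, 3, 4};
xla::Literal converted =
    xla::LiteralUtil::CreateR1<int>(host_values)
        .Convert(xla::primitive_util::NativeToPrimitiveType<float>())
        .ValueOrDie();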
@@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> input,
client_->TransferToServer(
*LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));

XlaBuilder b(TestName() + ".add");
Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1"));
@@ -38,7 +38,7 @@ class ExhaustiveF32ElementwiseOpTest

XlaBuilder builder(TestName());

std::unique_ptr<Literal> input_literal =
Literal input_literal =
LiteralUtil::CreateFromDimensions(F32, {input_size});
for (int64 i = begin; i < end; i++) {
if (i >= known_incorrect_range.first &&