Global de-std::unique_ptr cleanup for xla::Literal.
PiperOrigin-RevId: 212313258
This commit is contained in:
		
							parent
							
								
									656b3e9c84
								
							
						
					
					
						commit
						dd6d7c5c58
					
				@ -81,7 +81,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
 | 
			
		||||
      TF_ASSIGN_OR_RETURN(auto literal,
 | 
			
		||||
                          client->ComputeConstant(constant_graph));
 | 
			
		||||
      TF_RETURN_IF_ERROR(
 | 
			
		||||
          LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
 | 
			
		||||
          LiteralToHostTensor(literal, arg.type, &arg.constant_value));
 | 
			
		||||
    } else {
 | 
			
		||||
      arg.kind = XlaCompiler::Argument::kParameter;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
 | 
			
		||||
    std::vector<xla::XlaOp> args;
 | 
			
		||||
    args.push_back(ctx->Input(0));
 | 
			
		||||
    args.push_back(xla::ConstantLiteral(
 | 
			
		||||
        &b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
 | 
			
		||||
        &b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
 | 
			
		||||
    if (input_shape.dims() > 1) {
 | 
			
		||||
      // Don't bother passing the output shape and dim for the 1d case, since
 | 
			
		||||
      // the shape is always a scalar and the dim is always 0.
 | 
			
		||||
      args.push_back(xla::ConstantLiteral(
 | 
			
		||||
          &b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
 | 
			
		||||
          &b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
 | 
			
		||||
      args.push_back(
 | 
			
		||||
          xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
 | 
			
		||||
          xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    xla::Shape xla_shape =
 | 
			
		||||
 | 
			
		||||
@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
 | 
			
		||||
  xla::Literal literal;
 | 
			
		||||
  switch (type) {
 | 
			
		||||
    case xla::U8:
 | 
			
		||||
      literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value));
 | 
			
		||||
      literal = xla::LiteralUtil::CreateR0<uint8>(value);
 | 
			
		||||
      break;
 | 
			
		||||
    case xla::U32:
 | 
			
		||||
      literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value));
 | 
			
		||||
      literal = xla::LiteralUtil::CreateR0<uint32>(value);
 | 
			
		||||
      break;
 | 
			
		||||
    case xla::U64:
 | 
			
		||||
      literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value));
 | 
			
		||||
      literal = xla::LiteralUtil::CreateR0<uint64>(value);
 | 
			
		||||
      break;
 | 
			
		||||
    case xla::S8:
 | 
			
		||||
      literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value));
 | 
			
		||||
      literal = xla::LiteralUtil::CreateR0<int8>(value);
 | 
			
		||||
      break;
 | 
			
		||||
    case xla::S32:
 | 
			
		||||
      literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value));
 | 
			
		||||
      literal = xla::LiteralUtil::CreateR0<int32>(value);
 | 
			
		||||
      break;
 | 
			
		||||
    case xla::S64:
 | 
			
		||||
      literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value));
 | 
			
		||||
      literal = xla::LiteralUtil::CreateR0<int64>(value);
 | 
			
		||||
      break;
 | 
			
		||||
    case xla::F32:
 | 
			
		||||
      literal = std::move(*xla::LiteralUtil::CreateR0<float>(value));
 | 
			
		||||
      literal = xla::LiteralUtil::CreateR0<float>(value);
 | 
			
		||||
      break;
 | 
			
		||||
    case xla::F64:
 | 
			
		||||
      literal = std::move(*xla::LiteralUtil::CreateR0<double>(value));
 | 
			
		||||
      literal = xla::LiteralUtil::CreateR0<double>(value);
 | 
			
		||||
      break;
 | 
			
		||||
    case xla::C64:
 | 
			
		||||
      literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value));
 | 
			
		||||
      literal = xla::LiteralUtil::CreateR0<complex64>(value);
 | 
			
		||||
      break;
 | 
			
		||||
    case xla::PRED:
 | 
			
		||||
      LOG(FATAL) << "pred element type is not integral";
 | 
			
		||||
@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
 | 
			
		||||
    case xla::U16:
 | 
			
		||||
      LOG(FATAL) << "u16/s16 literals not yet implemented";
 | 
			
		||||
    case xla::BF16:
 | 
			
		||||
      literal = std::move(
 | 
			
		||||
          *xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
 | 
			
		||||
      literal =
 | 
			
		||||
          xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value));
 | 
			
		||||
      break;
 | 
			
		||||
    case xla::F16:
 | 
			
		||||
      literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>(
 | 
			
		||||
          static_cast<xla::half>(value)));
 | 
			
		||||
      literal =
 | 
			
		||||
          xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value));
 | 
			
		||||
      break;
 | 
			
		||||
    case xla::TUPLE:
 | 
			
		||||
      LOG(FATAL) << "tuple element type is not integral";
 | 
			
		||||
 | 
			
		||||
@ -27,19 +27,17 @@ TEST(LiteralUtil, LiteralToHostTensor) {
 | 
			
		||||
  // int64 literal can only be converted to an int64 host tensor.
 | 
			
		||||
  {
 | 
			
		||||
    std::vector<int64> int64_values = {1, 2, 3};
 | 
			
		||||
    std::unique_ptr<xla::Literal> int64_values_literal =
 | 
			
		||||
    xla::Literal int64_values_literal =
 | 
			
		||||
        xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
 | 
			
		||||
    Tensor host_tensor;
 | 
			
		||||
    EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
 | 
			
		||||
              LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
 | 
			
		||||
              LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor)
 | 
			
		||||
                  .error_message());
 | 
			
		||||
    EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32",
 | 
			
		||||
              LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor)
 | 
			
		||||
                  .error_message());
 | 
			
		||||
    EXPECT_EQ(
 | 
			
		||||
        "Cannot convert literal of type S64 to tensor of type qint32",
 | 
			
		||||
        LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor)
 | 
			
		||||
            .error_message());
 | 
			
		||||
    EXPECT_TRUE(
 | 
			
		||||
        LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor)
 | 
			
		||||
            .ok());
 | 
			
		||||
        LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok());
 | 
			
		||||
    test::ExpectTensorEqual<int64>(host_tensor,
 | 
			
		||||
                                   test::AsTensor<int64>(int64_values));
 | 
			
		||||
  }
 | 
			
		||||
@ -48,23 +46,22 @@ TEST(LiteralUtil, LiteralToHostTensor) {
 | 
			
		||||
    // Repeat tests with int32.
 | 
			
		||||
    Tensor host_tensor;
 | 
			
		||||
    std::vector<int32> int32_values = {10, 11};
 | 
			
		||||
    std::unique_ptr<xla::Literal> int32_values_literal =
 | 
			
		||||
    xla::Literal int32_values_literal =
 | 
			
		||||
        xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values));
 | 
			
		||||
    EXPECT_TRUE(
 | 
			
		||||
        LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
 | 
			
		||||
            .ok());
 | 
			
		||||
        LiteralToHostTensor(int32_values_literal, DT_INT32, &host_tensor).ok());
 | 
			
		||||
    test::ExpectTensorEqual<int32>(host_tensor,
 | 
			
		||||
                                   test::AsTensor<int32>(int32_values));
 | 
			
		||||
 | 
			
		||||
    EXPECT_TRUE(
 | 
			
		||||
        LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor)
 | 
			
		||||
        LiteralToHostTensor(int32_values_literal, DT_QINT32, &host_tensor)
 | 
			
		||||
            .ok());
 | 
			
		||||
    std::vector<qint32> qint32_values = {10, 11};
 | 
			
		||||
    test::ExpectTensorEqual<qint32>(host_tensor,
 | 
			
		||||
                                    test::AsTensor<qint32>(qint32_values));
 | 
			
		||||
 | 
			
		||||
    EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64",
 | 
			
		||||
              LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor)
 | 
			
		||||
              LiteralToHostTensor(int32_values_literal, DT_INT64, &host_tensor)
 | 
			
		||||
                  .error_message());
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) {
 | 
			
		||||
  // Set up arguments.
 | 
			
		||||
  auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
 | 
			
		||||
  auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
 | 
			
		||||
  auto x_global_or = client->TransferToServer(*x_literal);
 | 
			
		||||
  auto y_global_or = client->TransferToServer(*y_literal);
 | 
			
		||||
  auto x_global_or = client->TransferToServer(x_literal);
 | 
			
		||||
  auto y_global_or = client->TransferToServer(y_literal);
 | 
			
		||||
  TF_EXPECT_OK(x_global_or.status());
 | 
			
		||||
  TF_EXPECT_OK(y_global_or.status());
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> x_global =
 | 
			
		||||
@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) {
 | 
			
		||||
  auto result_or =
 | 
			
		||||
      client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
 | 
			
		||||
  TF_EXPECT_OK(result_or.status());
 | 
			
		||||
  std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
 | 
			
		||||
  EXPECT_EQ("(s32[]) (\n42\n)", result->ToString());
 | 
			
		||||
  xla::Literal result = std::move(result_or.ValueOrDie());
 | 
			
		||||
  EXPECT_EQ("(s32[]) (\n42\n)", result.ToString());
 | 
			
		||||
 | 
			
		||||
  config.mutable_feed(0)->mutable_id()->set_output_index(
 | 
			
		||||
      123); /* invalid output_index */
 | 
			
		||||
 | 
			
		||||
@ -208,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) {
 | 
			
		||||
                                     std::move(graph), args, &result));
 | 
			
		||||
 | 
			
		||||
  // Tests that the generated computation works.
 | 
			
		||||
  std::unique_ptr<xla::Literal> param0_literal =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
  std::unique_ptr<xla::Literal> param1_literal =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({-3, 101});
 | 
			
		||||
  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> param0_data =
 | 
			
		||||
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> param1_data =
 | 
			
		||||
      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> actual =
 | 
			
		||||
      client_
 | 
			
		||||
          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<xla::Literal> actual_literal =
 | 
			
		||||
      client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected0 =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({4, 143});
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected_literal =
 | 
			
		||||
      xla::LiteralUtil::MakeTuple({expected0.get()});
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 | 
			
		||||
  xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
 | 
			
		||||
  xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests compilation of a graph where the _Retval node is not necessarily last
 | 
			
		||||
@ -264,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
 | 
			
		||||
                                     args, &result));
 | 
			
		||||
 | 
			
		||||
  // Tests that the generated computation works.
 | 
			
		||||
  std::unique_ptr<xla::Literal> param0_literal =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
  std::unique_ptr<xla::Literal> param1_literal =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({-3, 101});
 | 
			
		||||
  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> param0_data =
 | 
			
		||||
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> param1_data =
 | 
			
		||||
      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> actual =
 | 
			
		||||
      client_
 | 
			
		||||
          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<xla::Literal> actual_literal =
 | 
			
		||||
      client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal));
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests that the compiler doesn't reorder the parameters.
 | 
			
		||||
@ -408,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
 | 
			
		||||
    EXPECT_FALSE(result.outputs[1].is_constant);
 | 
			
		||||
 | 
			
		||||
    // Tests that the generated computation works.
 | 
			
		||||
    std::unique_ptr<xla::Literal> param0_literal =
 | 
			
		||||
        xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
    xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
    std::unique_ptr<xla::GlobalData> param0_data =
 | 
			
		||||
        client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 | 
			
		||||
        client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    std::unique_ptr<xla::GlobalData> actual =
 | 
			
		||||
        client_->Execute(*result.computation, {param0_data.get()})
 | 
			
		||||
            .ConsumeValueOrDie();
 | 
			
		||||
    std::unique_ptr<xla::Literal> actual_literal =
 | 
			
		||||
    xla::Literal actual_literal =
 | 
			
		||||
        client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    std::unique_ptr<xla::Literal> expected0 =
 | 
			
		||||
        xla::LiteralUtil::CreateR1<int32>({-7, -42});
 | 
			
		||||
    std::unique_ptr<xla::Literal> expected_literal =
 | 
			
		||||
        xla::LiteralUtil::MakeTuple({expected0.get()});
 | 
			
		||||
    EXPECT_TRUE(
 | 
			
		||||
        xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 | 
			
		||||
    xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
 | 
			
		||||
    xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
 | 
			
		||||
    EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  {
 | 
			
		||||
@ -443,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
 | 
			
		||||
    EXPECT_FALSE(result.outputs[1].is_constant);
 | 
			
		||||
 | 
			
		||||
    // Tests that the generated computation works.
 | 
			
		||||
    std::unique_ptr<xla::Literal> param0_literal =
 | 
			
		||||
        xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
    xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
    std::unique_ptr<xla::GlobalData> param0_data =
 | 
			
		||||
        client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 | 
			
		||||
        client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    std::unique_ptr<xla::GlobalData> actual =
 | 
			
		||||
        client_->Execute(*result.computation, {param0_data.get()})
 | 
			
		||||
            .ConsumeValueOrDie();
 | 
			
		||||
    std::unique_ptr<xla::Literal> actual_literal =
 | 
			
		||||
    xla::Literal actual_literal =
 | 
			
		||||
        client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    std::unique_ptr<xla::Literal> expected0 =
 | 
			
		||||
        xla::LiteralUtil::CreateR0<int32>(7);
 | 
			
		||||
    std::unique_ptr<xla::Literal> expected1 =
 | 
			
		||||
        xla::LiteralUtil::CreateR1<int32>({-7, -42});
 | 
			
		||||
    std::unique_ptr<xla::Literal> expected =
 | 
			
		||||
        xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
 | 
			
		||||
    EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
 | 
			
		||||
    xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
 | 
			
		||||
    xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
 | 
			
		||||
    xla::Literal expected =
 | 
			
		||||
        xla::LiteralUtil::MakeTuple({&expected0, &expected1});
 | 
			
		||||
    EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -672,34 +657,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
 | 
			
		||||
            update.tensor_array_gradients_accessed);
 | 
			
		||||
 | 
			
		||||
  // Tests that the generated computation works.
 | 
			
		||||
  std::unique_ptr<xla::Literal> input_base =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
  std::unique_ptr<xla::Literal> input_grad2 =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({-3, 101});
 | 
			
		||||
  std::unique_ptr<xla::Literal> input =
 | 
			
		||||
      xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
 | 
			
		||||
  xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
  xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
 | 
			
		||||
  xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> param0_data =
 | 
			
		||||
      client_->TransferToServer(*input).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(input).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> actual =
 | 
			
		||||
      client_->Execute(*result.computation, {param0_data.get()})
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<xla::Literal> actual_literal =
 | 
			
		||||
      client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::Literal> output_read =
 | 
			
		||||
      xla::LiteralUtil::CreateR0<int32>(42);
 | 
			
		||||
  std::unique_ptr<xla::Literal> output_base =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
  std::unique_ptr<xla::Literal> output_grad1 =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({0, 1});
 | 
			
		||||
  std::unique_ptr<xla::Literal> output_grad2 =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({-3, 101});
 | 
			
		||||
  std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
 | 
			
		||||
      {output_base.get(), output_grad1.get(), output_grad2.get()});
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected_literal =
 | 
			
		||||
      xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 | 
			
		||||
  xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
 | 
			
		||||
  xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
  xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
 | 
			
		||||
  xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
 | 
			
		||||
  xla::Literal output_resource =
 | 
			
		||||
      xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
 | 
			
		||||
  xla::Literal expected_literal =
 | 
			
		||||
      xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests compilation and execution of a graph that adds two tensors.
 | 
			
		||||
@ -866,29 +843,24 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
 | 
			
		||||
 | 
			
		||||
void RunAndCheckVariablesComputation(
 | 
			
		||||
    xla::Client* client, const XlaCompiler::CompilationResult& result) {
 | 
			
		||||
  std::unique_ptr<xla::Literal> param0_literal =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
  std::unique_ptr<xla::Literal> param1_literal =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({-3, 101});
 | 
			
		||||
  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
 | 
			
		||||
  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> param0_data =
 | 
			
		||||
      client->TransferToServer(*param0_literal).ConsumeValueOrDie();
 | 
			
		||||
      client->TransferToServer(param0_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> param1_data =
 | 
			
		||||
      client->TransferToServer(*param1_literal).ConsumeValueOrDie();
 | 
			
		||||
      client->TransferToServer(param1_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> actual =
 | 
			
		||||
      client
 | 
			
		||||
          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<xla::Literal> actual_literal =
 | 
			
		||||
      client->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
  xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected0 =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({5, 144});
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected1 =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({4, 143});
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected_literal =
 | 
			
		||||
      xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 | 
			
		||||
  xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
 | 
			
		||||
  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
 | 
			
		||||
  xla::Literal expected_literal =
 | 
			
		||||
      xla::LiteralUtil::MakeTuple({&expected0, &expected1});
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests a simple graph that reads and writes a variable.
 | 
			
		||||
@ -952,20 +924,17 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
 | 
			
		||||
                                     std::move(graph), args, &result));
 | 
			
		||||
 | 
			
		||||
  // Tests that the generated computation works.
 | 
			
		||||
  std::unique_ptr<xla::Literal> param1_literal =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({-3, 101});
 | 
			
		||||
  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> param1_data =
 | 
			
		||||
      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> actual =
 | 
			
		||||
      client_->Execute(*result.computation, {param1_data.get()})
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<xla::Literal> actual_literal =
 | 
			
		||||
      client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected_literal =
 | 
			
		||||
      xla::LiteralUtil::MakeTuple({});
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 | 
			
		||||
  xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(XlaCompilerTest, ReturnResourceHandle) {
 | 
			
		||||
@ -1069,29 +1038,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
 | 
			
		||||
           xla::ShapeUtil::MakeShape(xla::S32, {4})})));
 | 
			
		||||
 | 
			
		||||
  // Tests that the generated computation works.
 | 
			
		||||
  std::unique_ptr<xla::Literal> param0_literal =
 | 
			
		||||
  xla::Literal param0_literal =
 | 
			
		||||
      xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
 | 
			
		||||
  std::unique_ptr<xla::Literal> param1_literal =
 | 
			
		||||
  xla::Literal param1_literal =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> param0_data =
 | 
			
		||||
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> param1_data =
 | 
			
		||||
      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> actual =
 | 
			
		||||
      client_
 | 
			
		||||
          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<xla::Literal> actual_literal =
 | 
			
		||||
      client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected0 =
 | 
			
		||||
  xla::Literal expected0 =
 | 
			
		||||
      xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected1 =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected_literal =
 | 
			
		||||
      xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 | 
			
		||||
  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
 | 
			
		||||
  xla::Literal expected_literal =
 | 
			
		||||
      xla::LiteralUtil::MakeTuple({&expected0, &expected1});
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
 | 
			
		||||
@ -1138,29 +1105,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
 | 
			
		||||
           xla::ShapeUtil::MakeShape(xla::S32, {4})})));
 | 
			
		||||
 | 
			
		||||
  // Tests that the generated computation works.
 | 
			
		||||
  std::unique_ptr<xla::Literal> param0_literal =
 | 
			
		||||
  xla::Literal param0_literal =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
 | 
			
		||||
  std::unique_ptr<xla::Literal> param1_literal =
 | 
			
		||||
  xla::Literal param1_literal =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> param0_data =
 | 
			
		||||
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> param1_data =
 | 
			
		||||
      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::GlobalData> actual =
 | 
			
		||||
      client_
 | 
			
		||||
          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<xla::Literal> actual_literal =
 | 
			
		||||
      client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected0 =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected1 =
 | 
			
		||||
      xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
 | 
			
		||||
  std::unique_ptr<xla::Literal> expected_literal =
 | 
			
		||||
      xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 | 
			
		||||
  xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
 | 
			
		||||
  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
 | 
			
		||||
  xla::Literal expected_literal =
 | 
			
		||||
      xla::LiteralUtil::MakeTuple({&expected0, &expected1});
 | 
			
		||||
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests a graph which has a function with an invalid op.
 | 
			
		||||
 | 
			
		||||
@ -213,16 +213,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
 | 
			
		||||
        context_->op_kernel().name(), " input ", index,
 | 
			
		||||
        ".\nError: ", constant_graph.status().error_message());
 | 
			
		||||
  }
 | 
			
		||||
  xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
 | 
			
		||||
      compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(),
 | 
			
		||||
                                            &layout);
 | 
			
		||||
  xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
 | 
			
		||||
      constant_graph.ValueOrDie(), &layout);
 | 
			
		||||
  if (!computed.ok()) {
 | 
			
		||||
    return errors::Internal("Error evaluating ", context_->op_kernel().name(),
 | 
			
		||||
                            " input ", index,
 | 
			
		||||
                            " as a compile-time constant.\nError: ",
 | 
			
		||||
                            computed.status().error_message());
 | 
			
		||||
  }
 | 
			
		||||
  *constant_literal = std::move(*computed.ValueOrDie());
 | 
			
		||||
  *constant_literal = std::move(computed).ValueOrDie();
 | 
			
		||||
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -37,8 +37,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {}
 | 
			
		||||
 | 
			
		||||
Client::~Client() = default;
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> Client::Transfer(
 | 
			
		||||
    const GlobalData& data, const Shape* shape_with_layout) {
 | 
			
		||||
StatusOr<Literal> Client::Transfer(const GlobalData& data,
 | 
			
		||||
                                   const Shape* shape_with_layout) {
 | 
			
		||||
  TransferToClientRequest request;
 | 
			
		||||
  *request.mutable_data() = data.handle();
 | 
			
		||||
  if (shape_with_layout != nullptr) {
 | 
			
		||||
@ -114,7 +114,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
 | 
			
		||||
StatusOr<Literal> Client::TransferFromOutfeed(
 | 
			
		||||
    const Shape* shape_with_layout, int64 replica_id,
 | 
			
		||||
    const DeviceHandle* device_handle) {
 | 
			
		||||
  TransferFromOutfeedRequest request;
 | 
			
		||||
@ -162,7 +162,7 @@ Status Client::ResetDevice() {
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
 | 
			
		||||
StatusOr<Literal> Client::ExecuteAndTransfer(
 | 
			
		||||
    const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
 | 
			
		||||
    const ExecutionOptions* execution_options,
 | 
			
		||||
    ExecutionProfile* execution_profile) {
 | 
			
		||||
@ -177,8 +177,8 @@ StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
 | 
			
		||||
  return Transfer(*data, shape_with_output_layout);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
 | 
			
		||||
    const XlaComputation& computation, const Layout* output_layout) const {
 | 
			
		||||
StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
 | 
			
		||||
                                          const Layout* output_layout) const {
 | 
			
		||||
  ComputeConstantGraphRequest request;
 | 
			
		||||
  *request.mutable_computation() = computation.proto();
 | 
			
		||||
  if (output_layout != nullptr) {
 | 
			
		||||
 | 
			
		||||
@ -96,8 +96,8 @@ class Client {
 | 
			
		||||
  //
 | 
			
		||||
  // If shape_with_layout is not nullptr, it points to a shape whose layout will
 | 
			
		||||
  // be the layout of the returned literal.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> Transfer(
 | 
			
		||||
      const GlobalData& data, const Shape* shape_with_layout = nullptr);
 | 
			
		||||
  StatusOr<Literal> Transfer(const GlobalData& data,
 | 
			
		||||
                             const Shape* shape_with_layout = nullptr);
 | 
			
		||||
 | 
			
		||||
  // Transfer the given literal to the server. This allocates memory on the
 | 
			
		||||
  // device and copies the literal's contents over. Returns a global data handle
 | 
			
		||||
@ -122,7 +122,7 @@ class Client {
 | 
			
		||||
  // device_handle and replica_id together specify a particular device; a device
 | 
			
		||||
  // assigned for the given replica_id among the replicas that the given device
 | 
			
		||||
  // handle belongs to.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> TransferFromOutfeed(
 | 
			
		||||
  StatusOr<Literal> TransferFromOutfeed(
 | 
			
		||||
      const Shape* shape_with_layout, int64 replica_id = 0,
 | 
			
		||||
      const DeviceHandle* device_handle = nullptr);
 | 
			
		||||
 | 
			
		||||
@ -132,7 +132,7 @@ class Client {
 | 
			
		||||
  // Executes the computation with the given arguments and transfers the result
 | 
			
		||||
  // to the client as a literal. Parameters are defined the same as for
 | 
			
		||||
  // Execute() and Transfer().
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
 | 
			
		||||
  StatusOr<Literal> ExecuteAndTransfer(
 | 
			
		||||
      const XlaComputation& computation,
 | 
			
		||||
      absl::Span<GlobalData* const> arguments,
 | 
			
		||||
      const ExecutionOptions* execution_options = nullptr,
 | 
			
		||||
@ -153,7 +153,7 @@ class Client {
 | 
			
		||||
  //
 | 
			
		||||
  // If output_layout is non-null, then the output of the computation will be
 | 
			
		||||
  // stored using that layout.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ComputeConstant(
 | 
			
		||||
  StatusOr<Literal> ComputeConstant(
 | 
			
		||||
      const XlaComputation& computation,
 | 
			
		||||
      const Layout* output_layout = nullptr) const;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -76,7 +76,7 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
 | 
			
		||||
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
 | 
			
		||||
                                              Client* client) {
 | 
			
		||||
  if (DataSizeOfShape(shape) < (1LL << 20)) {
 | 
			
		||||
    StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
 | 
			
		||||
    StatusOr<Literal> literal_status = MakeFakeLiteral(shape);
 | 
			
		||||
    if (!literal_status.ok()) {
 | 
			
		||||
      // If we got an Unimplemented error, fall back to making the fake data via
 | 
			
		||||
      // an on-device computation.
 | 
			
		||||
@ -84,7 +84,7 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
 | 
			
		||||
               tensorflow::error::UNIMPLEMENTED);
 | 
			
		||||
      return MakeFakeDataViaDeviceOrDie(shape, client);
 | 
			
		||||
    }
 | 
			
		||||
    return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie();
 | 
			
		||||
    return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // If the data is large, generate it on-device.
 | 
			
		||||
 | 
			
		||||
@ -195,9 +195,8 @@ Status LocalExecutable::RecordArguments(
 | 
			
		||||
    HloSnapshot* hlo_snapshot) {
 | 
			
		||||
  hlo_snapshot->clear_arguments();
 | 
			
		||||
  for (const ShapedBuffer* argument : arguments) {
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
 | 
			
		||||
                        LiteralFromShapedBuffer(*argument));
 | 
			
		||||
    *hlo_snapshot->add_arguments() = literal->ToProto();
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument));
 | 
			
		||||
    *hlo_snapshot->add_arguments() = literal.ToProto();
 | 
			
		||||
  }
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
@ -205,13 +204,12 @@ Status LocalExecutable::RecordArguments(
 | 
			
		||||
Status LocalExecutable::RecordResult(const ShapedBuffer* result,
 | 
			
		||||
                                     HloSnapshot* hlo_snapshot) {
 | 
			
		||||
  hlo_snapshot->clear_result();
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
 | 
			
		||||
                      LiteralFromShapedBuffer(*result));
 | 
			
		||||
  *hlo_snapshot->mutable_result() = literal->ToProto();
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result));
 | 
			
		||||
  *hlo_snapshot->mutable_result() = literal.ToProto();
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> LocalExecutable::LiteralFromShapedBuffer(
 | 
			
		||||
StatusOr<Literal> LocalExecutable::LiteralFromShapedBuffer(
 | 
			
		||||
    const ShapedBuffer& shaped_buffer) {
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(auto stream,
 | 
			
		||||
                      backend_->BorrowStream(shaped_buffer.device_ordinal()));
 | 
			
		||||
@ -277,7 +275,7 @@ StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
 | 
			
		||||
  return std::move(scoped_buffer);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
 | 
			
		||||
StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
 | 
			
		||||
    const ShapedBuffer& shaped_buffer) {
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
 | 
			
		||||
                                       shaped_buffer.device_ordinal()));
 | 
			
		||||
@ -298,13 +296,13 @@ Status LocalClient::TransferToInfeedLocal(const Literal& literal,
 | 
			
		||||
                                                               literal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
 | 
			
		||||
    const Shape& shape, int device_ordinal) {
 | 
			
		||||
StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
 | 
			
		||||
                                                        int device_ordinal) {
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
 | 
			
		||||
                      backend().stream_executor(device_ordinal));
 | 
			
		||||
  auto literal = Literal::CreateFromShape(shape);
 | 
			
		||||
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
 | 
			
		||||
      executor, shape, literal.get()));
 | 
			
		||||
      executor, shape, &literal));
 | 
			
		||||
  return std::move(literal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -84,8 +84,7 @@ class LocalExecutable {
 | 
			
		||||
  Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
 | 
			
		||||
 | 
			
		||||
  // Returns a literal containing the contents of the given ShapedBuffer.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
 | 
			
		||||
      const ShapedBuffer& shaped_buffer);
 | 
			
		||||
  StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
 | 
			
		||||
 | 
			
		||||
  // The ordinal of the device which this executable was compiled for. The
 | 
			
		||||
  // executable can run on all equivalent devices (as determined by
 | 
			
		||||
@ -132,8 +131,7 @@ class LocalClient : public Client {
 | 
			
		||||
 | 
			
		||||
  // Copy the data from the device contained in the given ShapedBuffer and
 | 
			
		||||
  // return as a Literal.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
 | 
			
		||||
      const ShapedBuffer& shaped_buffer);
 | 
			
		||||
  StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
 | 
			
		||||
 | 
			
		||||
  // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
 | 
			
		||||
  // as long as the handle is valid.
 | 
			
		||||
@ -151,8 +149,8 @@ class LocalClient : public Client {
 | 
			
		||||
  // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
 | 
			
		||||
  // not inherit from Client and there is no possibility of confusion with
 | 
			
		||||
  // Client::TransferFromOutfeed.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
 | 
			
		||||
      const Shape& shape, int device_ordinal);
 | 
			
		||||
  StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
 | 
			
		||||
                                             int device_ordinal);
 | 
			
		||||
 | 
			
		||||
  // Returns the device ordinal that corresponds to the given replica number.
 | 
			
		||||
  //
 | 
			
		||||
 | 
			
		||||
@ -738,7 +738,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
 | 
			
		||||
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
 | 
			
		||||
    HloInstructionProto instr;
 | 
			
		||||
    *instr.mutable_shape() = ShapeUtil::MakeNil();
 | 
			
		||||
    *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto();
 | 
			
		||||
    *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
 | 
			
		||||
    return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -2112,12 +2112,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp XlaBuilder::ConstantR0(NativeT value) {
 | 
			
		||||
  return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
 | 
			
		||||
  return ConstantLiteral(LiteralUtil::CreateR0<NativeT>(value));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) {
 | 
			
		||||
  return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
 | 
			
		||||
  return ConstantLiteral(LiteralUtil::CreateR1<NativeT>(values));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
@ -2129,44 +2129,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
 | 
			
		||||
  return ConstantLiteral(*LiteralUtil::CreateR1(values));
 | 
			
		||||
  return ConstantLiteral(LiteralUtil::CreateR1(values));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp XlaBuilder::ConstantR2(
 | 
			
		||||
    std::initializer_list<std::initializer_list<NativeT>> values) {
 | 
			
		||||
  return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
 | 
			
		||||
  return ConstantLiteral(LiteralUtil::CreateR2<NativeT>(values));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
 | 
			
		||||
                                              const Layout& layout) {
 | 
			
		||||
  return ConstantLiteral(
 | 
			
		||||
      *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
 | 
			
		||||
      LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
 | 
			
		||||
  return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
 | 
			
		||||
  return ConstantLiteral(LiteralUtil::CreateFromArray<NativeT>(values));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
 | 
			
		||||
    const Array2D<NativeT>& values, const Layout& layout) {
 | 
			
		||||
  return ConstantLiteral(
 | 
			
		||||
      *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
 | 
			
		||||
      LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
 | 
			
		||||
  return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
 | 
			
		||||
  return ConstantLiteral(LiteralUtil::CreateR2FromArray2D<NativeT>(values));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
 | 
			
		||||
    const Array3D<NativeT>& values, const Layout& layout) {
 | 
			
		||||
  return ConstantLiteral(
 | 
			
		||||
      *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
 | 
			
		||||
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
@ -2189,12 +2189,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
 | 
			
		||||
  return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
 | 
			
		||||
  return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
 | 
			
		||||
  return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
 | 
			
		||||
  return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
@ -2207,13 +2207,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
 | 
			
		||||
 | 
			
		||||
inline XlaOp ConstantR1(XlaBuilder* builder,
 | 
			
		||||
                        const tensorflow::core::Bitmap& values) {
 | 
			
		||||
  return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
 | 
			
		||||
  return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp ConstantR2(XlaBuilder* builder,
 | 
			
		||||
                 std::initializer_list<std::initializer_list<NativeT>> values) {
 | 
			
		||||
  return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
 | 
			
		||||
  return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
@ -2221,14 +2221,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
 | 
			
		||||
                                  const Array<NativeT>& values,
 | 
			
		||||
                                  const Layout& layout) {
 | 
			
		||||
  return ConstantLiteral(
 | 
			
		||||
      builder,
 | 
			
		||||
      *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
 | 
			
		||||
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
 | 
			
		||||
  return ConstantLiteral(builder,
 | 
			
		||||
                         *LiteralUtil::CreateFromArray<NativeT>(values));
 | 
			
		||||
                         LiteralUtil::CreateFromArray<NativeT>(values));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
@ -2236,15 +2235,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
 | 
			
		||||
                                      const Array2D<NativeT>& values,
 | 
			
		||||
                                      const Layout& layout) {
 | 
			
		||||
  return ConstantLiteral(
 | 
			
		||||
      builder,
 | 
			
		||||
      *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
 | 
			
		||||
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
 | 
			
		||||
                            const Array2D<NativeT>& values) {
 | 
			
		||||
  return ConstantLiteral(builder,
 | 
			
		||||
                         *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
 | 
			
		||||
                         LiteralUtil::CreateR2FromArray2D<NativeT>(values));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
@ -2253,7 +2251,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
 | 
			
		||||
                                      const Layout& layout) {
 | 
			
		||||
  return ConstantLiteral(
 | 
			
		||||
      builder,
 | 
			
		||||
      *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
 | 
			
		||||
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
 | 
			
		||||
@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) {
 | 
			
		||||
  return *this;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
 | 
			
		||||
  auto literal = absl::make_unique<Literal>(shape);
 | 
			
		||||
  literal->root_piece_->ForEachMutableSubpiece(
 | 
			
		||||
Literal LiteralBase::CreateFromShape(const Shape& shape) {
 | 
			
		||||
  Literal literal(shape);
 | 
			
		||||
  literal.root_piece_->ForEachMutableSubpiece(
 | 
			
		||||
      [&](const ShapeIndex& index, Piece* piece) {
 | 
			
		||||
        if (ShapeUtil::IsArray(piece->subshape())) {
 | 
			
		||||
          memset(piece->untyped_data(), 0, piece->size_bytes());
 | 
			
		||||
@ -278,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ StatusOr<std::unique_ptr<Literal>>
 | 
			
		||||
MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
 | 
			
		||||
/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
 | 
			
		||||
    const LiteralProto& proto) {
 | 
			
		||||
  if (!proto.has_shape()) {
 | 
			
		||||
    return InvalidArgument("LiteralProto has no shape");
 | 
			
		||||
  }
 | 
			
		||||
@ -287,9 +287,9 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
 | 
			
		||||
    return InvalidArgument("LiteralProto has no layout");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto literal = absl::make_unique<Literal>(proto.shape());
 | 
			
		||||
  Literal literal(proto.shape());
 | 
			
		||||
 | 
			
		||||
  TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
 | 
			
		||||
  TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
 | 
			
		||||
      [&](const ShapeIndex& index, Piece* piece) {
 | 
			
		||||
        const LiteralProto* proto_element = &proto;
 | 
			
		||||
        for (int64 i : index) {
 | 
			
		||||
@ -556,38 +556,37 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::unique_ptr<Literal> LiteralBase::Relayout(
 | 
			
		||||
    const Layout& new_layout, const ShapeIndex& shape_index) const {
 | 
			
		||||
Literal LiteralBase::Relayout(const Layout& new_layout,
 | 
			
		||||
                              const ShapeIndex& shape_index) const {
 | 
			
		||||
  // Create new shape with 'new_layout' set at the given shape index.
 | 
			
		||||
  Shape new_shape = shape();
 | 
			
		||||
  Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
 | 
			
		||||
  TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
 | 
			
		||||
  *subshape->mutable_layout() = new_layout;
 | 
			
		||||
  auto result = absl::make_unique<Literal>(new_shape);
 | 
			
		||||
  TF_CHECK_OK(result->CopyFrom(*this));
 | 
			
		||||
  Literal result(new_shape);
 | 
			
		||||
  TF_CHECK_OK(result.CopyFrom(*this));
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::unique_ptr<Literal> LiteralBase::Relayout(
 | 
			
		||||
    const Shape& shape_with_layout) const {
 | 
			
		||||
Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
 | 
			
		||||
  CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
 | 
			
		||||
      << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
 | 
			
		||||
      << " not compatible with literal shape "
 | 
			
		||||
      << ShapeUtil::HumanString(shape());
 | 
			
		||||
  std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
 | 
			
		||||
  Literal result = CreateFromShape(shape_with_layout);
 | 
			
		||||
  ShapeUtil::ForEachSubshape(
 | 
			
		||||
      result->shape(),
 | 
			
		||||
      result.shape(),
 | 
			
		||||
      [this, &result](const Shape& subshape, const ShapeIndex& index) {
 | 
			
		||||
        if (ShapeUtil::IsArray(subshape)) {
 | 
			
		||||
          TF_CHECK_OK(result->CopyFrom(*this,
 | 
			
		||||
                                       /*dest_shape_index=*/index,
 | 
			
		||||
                                       /*src_shape_index=*/index));
 | 
			
		||||
          TF_CHECK_OK(result.CopyFrom(*this,
 | 
			
		||||
                                      /*dest_shape_index=*/index,
 | 
			
		||||
                                      /*src_shape_index=*/index));
 | 
			
		||||
        }
 | 
			
		||||
      });
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
 | 
			
		||||
StatusOr<Literal> LiteralBase::Broadcast(
 | 
			
		||||
    const Shape& result_shape, absl::Span<const int64> dimensions) const {
 | 
			
		||||
  if (!ShapeUtil::IsArray(shape())) {
 | 
			
		||||
    return InvalidArgument("Broadcast only supports arrays.");
 | 
			
		||||
@ -598,14 +597,14 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
 | 
			
		||||
                 result_shape.dimensions(dimensions[i]));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> result = absl::make_unique<Literal>(result_shape);
 | 
			
		||||
  Literal result(result_shape);
 | 
			
		||||
 | 
			
		||||
  // scratch_source_index is temporary storage space for the computed index into
 | 
			
		||||
  // the input literal.  We put it here to avoid allocating an std::vector in
 | 
			
		||||
  // every iteration of ShapeUtil::ForEachIndex.
 | 
			
		||||
  std::vector<int64> scratch_source_index(shape().dimensions_size());
 | 
			
		||||
 | 
			
		||||
  char* dest_data = static_cast<char*>(result->untyped_data());
 | 
			
		||||
  char* dest_data = static_cast<char*>(result.untyped_data());
 | 
			
		||||
  const char* source_data = static_cast<const char*>(untyped_data());
 | 
			
		||||
  const int64 primitive_size =
 | 
			
		||||
      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
 | 
			
		||||
@ -627,37 +626,36 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
 | 
			
		||||
  return std::move(result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
 | 
			
		||||
StatusOr<Literal> LiteralBase::Reshape(
 | 
			
		||||
    absl::Span<const int64> dimensions) const {
 | 
			
		||||
  if (!ShapeUtil::IsArray(shape())) {
 | 
			
		||||
    return InvalidArgument("Reshape does not support tuples.");
 | 
			
		||||
  }
 | 
			
		||||
  std::unique_ptr<Literal> output;
 | 
			
		||||
  Literal output;
 | 
			
		||||
  if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
 | 
			
		||||
    output =
 | 
			
		||||
        Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
 | 
			
		||||
  } else {
 | 
			
		||||
    output = CloneToUnique();
 | 
			
		||||
    output = Clone();
 | 
			
		||||
  }
 | 
			
		||||
  // Because the layout is monotonic, we can simply reuse the same sequence of
 | 
			
		||||
  // values without changing their order.
 | 
			
		||||
  *output->mutable_shape_do_not_use() =
 | 
			
		||||
  *output.mutable_shape_do_not_use() =
 | 
			
		||||
      ShapeUtil::MakeShape(shape().element_type(), dimensions);
 | 
			
		||||
 | 
			
		||||
  int64 elements_before = ShapeUtil::ElementsIn(shape());
 | 
			
		||||
  int64 elements_after = ShapeUtil::ElementsIn(output->shape());
 | 
			
		||||
  int64 elements_after = ShapeUtil::ElementsIn(output.shape());
 | 
			
		||||
  if (elements_before != elements_after) {
 | 
			
		||||
    return InvalidArgument(
 | 
			
		||||
        "Shapes before and after Literal::Reshape have different numbers "
 | 
			
		||||
        "of elements: %s vs %s.",
 | 
			
		||||
        ShapeUtil::HumanString(shape()),
 | 
			
		||||
        ShapeUtil::HumanString(output->shape()));
 | 
			
		||||
        ShapeUtil::HumanString(output.shape()));
 | 
			
		||||
  }
 | 
			
		||||
  return std::move(output);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::unique_ptr<Literal> LiteralBase::Transpose(
 | 
			
		||||
    absl::Span<const int64> permutation) const {
 | 
			
		||||
Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
 | 
			
		||||
  CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
 | 
			
		||||
  CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
 | 
			
		||||
      << "Given permutation is not a permutation of dimension numbers";
 | 
			
		||||
@ -687,32 +685,31 @@ std::unique_ptr<Literal> LiteralBase::Transpose(
 | 
			
		||||
  for (auto index : LayoutUtil::MinorToMajor(shape())) {
 | 
			
		||||
    layout->add_minor_to_major(inverse_permutation[index]);
 | 
			
		||||
  }
 | 
			
		||||
  auto new_literal = absl::make_unique<Literal>(permuted_shape);
 | 
			
		||||
  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
 | 
			
		||||
  Literal new_literal(permuted_shape);
 | 
			
		||||
  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
 | 
			
		||||
            ShapeUtil::ByteSizeOf(shape()));
 | 
			
		||||
  std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
 | 
			
		||||
  std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
 | 
			
		||||
  return new_literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
std::unique_ptr<Literal> LiteralBase::SliceInternal(
 | 
			
		||||
Literal LiteralBase::SliceInternal(
 | 
			
		||||
    const Shape& result_shape, absl::Span<const int64> start_indices) const {
 | 
			
		||||
  auto result_literal = absl::make_unique<Literal>(result_shape);
 | 
			
		||||
  Literal result_literal(result_shape);
 | 
			
		||||
  DimensionVector new_indices(ShapeUtil::Rank(result_shape));
 | 
			
		||||
  result_literal->EachCell<NativeT>(
 | 
			
		||||
  result_literal.EachCell<NativeT>(
 | 
			
		||||
      [&](absl::Span<const int64> indices, NativeT /*value*/) {
 | 
			
		||||
        for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
 | 
			
		||||
          new_indices[i] = indices[i] + start_indices[i];
 | 
			
		||||
        }
 | 
			
		||||
        NativeT value = Get<NativeT>(new_indices);
 | 
			
		||||
        result_literal->Set<NativeT>(indices, value);
 | 
			
		||||
        result_literal.Set<NativeT>(indices, value);
 | 
			
		||||
      });
 | 
			
		||||
  return result_literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::unique_ptr<Literal> LiteralBase::Slice(
 | 
			
		||||
    absl::Span<const int64> start_indices,
 | 
			
		||||
    absl::Span<const int64> limit_indices) const {
 | 
			
		||||
Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
 | 
			
		||||
                           absl::Span<const int64> limit_indices) const {
 | 
			
		||||
  CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
 | 
			
		||||
 | 
			
		||||
  DimensionVector result_dimensions;
 | 
			
		||||
@ -750,12 +747,6 @@ Literal LiteralBase::Clone() const {
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
 | 
			
		||||
  auto result = absl::make_unique<Literal>(shape());
 | 
			
		||||
  TF_CHECK_OK(result->CopyFrom(*this));
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
 | 
			
		||||
                                const ShapeIndex& shape_index) const {
 | 
			
		||||
  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
 | 
			
		||||
@ -1191,14 +1182,14 @@ void LiteralBase::EachCellAsString(
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
 | 
			
		||||
std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
 | 
			
		||||
    const LiteralBase& src_literal, const ConverterType& converter) {
 | 
			
		||||
Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
 | 
			
		||||
                                               const ConverterType& converter) {
 | 
			
		||||
  CHECK(ShapeUtil::IsArray(src_literal.shape()));
 | 
			
		||||
  auto result_literal = absl::make_unique<Literal>(ShapeUtil::ChangeElementType(
 | 
			
		||||
  Literal result_literal(ShapeUtil::ChangeElementType(
 | 
			
		||||
      src_literal.shape(),
 | 
			
		||||
      primitive_util::NativeToPrimitiveType<NativeDestT>()));
 | 
			
		||||
  auto src_data = src_literal.data<NativeSrcT>();
 | 
			
		||||
  auto dest_data = result_literal->template data<NativeDestT>();
 | 
			
		||||
  auto dest_data = result_literal.template data<NativeDestT>();
 | 
			
		||||
  int64 num_elements = src_literal.element_count();
 | 
			
		||||
 | 
			
		||||
  for (int64 i = 0; i < num_elements; ++i) {
 | 
			
		||||
@ -1208,8 +1199,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeSrcT, typename NativeDestT>
 | 
			
		||||
std::unique_ptr<Literal> ConvertBetweenNativeTypes(
 | 
			
		||||
    const LiteralBase& src_literal) {
 | 
			
		||||
Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
 | 
			
		||||
  auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
 | 
			
		||||
  return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
 | 
			
		||||
      src_literal, converter);
 | 
			
		||||
@ -1217,7 +1207,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(
 | 
			
		||||
 | 
			
		||||
template <typename NativeSrcT, typename NativeDestT>
 | 
			
		||||
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
 | 
			
		||||
                        std::unique_ptr<Literal>>::type
 | 
			
		||||
                        Literal>::type
 | 
			
		||||
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
 | 
			
		||||
  auto converter = [](NativeSrcT src) {
 | 
			
		||||
    return tensorflow::bit_cast<NativeDestT>(src);
 | 
			
		||||
@ -1232,20 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
 | 
			
		||||
// identical sizes higher up.
 | 
			
		||||
template <typename NativeSrcT, typename NativeDestT>
 | 
			
		||||
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
 | 
			
		||||
                        std::unique_ptr<Literal>>::type
 | 
			
		||||
                        Literal>::type
 | 
			
		||||
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
 | 
			
		||||
  LOG(FATAL) << "Invalid bitcast between types of different sizes.";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <PrimitiveType primitive_src_type>
 | 
			
		||||
std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
 | 
			
		||||
Literal ConvertToC64(const LiteralBase& src_literal) {
 | 
			
		||||
  CHECK(ShapeUtil::IsArray(src_literal.shape()));
 | 
			
		||||
  auto result_literal = absl::make_unique<Literal>(
 | 
			
		||||
  Literal result_literal(
 | 
			
		||||
      ShapeUtil::ChangeElementType(src_literal.shape(), C64));
 | 
			
		||||
  using NativeSrcT =
 | 
			
		||||
      typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
 | 
			
		||||
  absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
 | 
			
		||||
  absl::Span<complex64> dest_data = result_literal->data<complex64>();
 | 
			
		||||
  absl::Span<complex64> dest_data = result_literal.data<complex64>();
 | 
			
		||||
  int64 num_elements = src_literal.element_count();
 | 
			
		||||
  for (int64 i = 0; i < num_elements; ++i) {
 | 
			
		||||
    dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
 | 
			
		||||
@ -1254,8 +1244,7 @@ std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
 | 
			
		||||
std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
 | 
			
		||||
                                             bool bitcast) {
 | 
			
		||||
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
 | 
			
		||||
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
 | 
			
		||||
  if (bitcast) {
 | 
			
		||||
    return BitcastBetweenNativeTypes<
 | 
			
		||||
@ -1273,9 +1262,9 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <PrimitiveType primitive_src_type>
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
 | 
			
		||||
    const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
 | 
			
		||||
    bool bitcast) {
 | 
			
		||||
StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
 | 
			
		||||
                                           PrimitiveType primitive_dest_type,
 | 
			
		||||
                                           bool bitcast) {
 | 
			
		||||
  switch (primitive_dest_type) {
 | 
			
		||||
#define CONVERT_IF_TYPES_MATCH(type)                                    \
 | 
			
		||||
  case (type):                                                          \
 | 
			
		||||
@ -1307,12 +1296,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
 | 
			
		||||
                       PrimitiveType_Name(primitive_dest_type));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
 | 
			
		||||
    const LiteralBase& literal, PrimitiveType primitive_dest_type,
 | 
			
		||||
    bool bitcast) {
 | 
			
		||||
StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
 | 
			
		||||
                                PrimitiveType primitive_dest_type,
 | 
			
		||||
                                bool bitcast) {
 | 
			
		||||
  TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
 | 
			
		||||
  if (literal.shape().element_type() == primitive_dest_type) {
 | 
			
		||||
    return literal.CloneToUnique();
 | 
			
		||||
    return literal.Clone();
 | 
			
		||||
  }
 | 
			
		||||
  switch (literal.shape().element_type()) {
 | 
			
		||||
#define CONVERT_IF_DEST_TYPE_MATCHES(type)                                \
 | 
			
		||||
@ -1342,12 +1331,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
 | 
			
		||||
StatusOr<Literal> LiteralBase::Convert(
 | 
			
		||||
    PrimitiveType primitive_dest_type) const {
 | 
			
		||||
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
 | 
			
		||||
StatusOr<Literal> LiteralBase::BitcastConvert(
 | 
			
		||||
    PrimitiveType primitive_dest_type) const {
 | 
			
		||||
  if (primitive_util::BitWidth(shape().element_type()) !=
 | 
			
		||||
      primitive_util::BitWidth(primitive_dest_type)) {
 | 
			
		||||
@ -1362,8 +1351,8 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
 | 
			
		||||
  return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
 | 
			
		||||
    const Shape& dest_shape, bool round_f32_to_bf16) const {
 | 
			
		||||
StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape,
 | 
			
		||||
                                              bool round_f32_to_bf16) const {
 | 
			
		||||
  if (!ShapeUtil::IsTuple(dest_shape)) {
 | 
			
		||||
    if (round_f32_to_bf16 && shape().element_type() == F32 &&
 | 
			
		||||
        dest_shape.element_type() == BF16) {
 | 
			
		||||
@ -1381,11 +1370,9 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(
 | 
			
		||||
        auto new_element,
 | 
			
		||||
        element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
 | 
			
		||||
    elements.push_back(std::move(*new_element));
 | 
			
		||||
    elements.push_back(std::move(new_element));
 | 
			
		||||
  }
 | 
			
		||||
  auto converted = absl::make_unique<Literal>();
 | 
			
		||||
  *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
 | 
			
		||||
  return std::move(converted);
 | 
			
		||||
  return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ Literal MutableLiteralBase::MoveIntoTuple(
 | 
			
		||||
 | 
			
		||||
@ -223,25 +223,21 @@ class LiteralBase {
 | 
			
		||||
  //
 | 
			
		||||
  // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
 | 
			
		||||
  // the default behavior.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ConvertToShape(
 | 
			
		||||
      const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
 | 
			
		||||
  StatusOr<Literal> ConvertToShape(const Shape& dest_shape,
 | 
			
		||||
                                   bool round_f32_to_bf16 = false) const;
 | 
			
		||||
 | 
			
		||||
  // Converts this literal to another primitive type using a bitcast
 | 
			
		||||
  // conversion. The to and from primitive types must have the same bit
 | 
			
		||||
  // width. Returns an error if the conversion is not possible. This literal
 | 
			
		||||
  // must be array-shaped.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> BitcastConvert(
 | 
			
		||||
      PrimitiveType primitive_dest_type) const;
 | 
			
		||||
  StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
 | 
			
		||||
 | 
			
		||||
  // Converts this literal to another primitive type. Returns an error if the
 | 
			
		||||
  // conversion is not possible. This literal must be array-shaped.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> Convert(
 | 
			
		||||
      PrimitiveType primitive_dest_type) const;
 | 
			
		||||
  StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
 | 
			
		||||
 | 
			
		||||
  // Clones the underlying buffers into a new Literal, or new
 | 
			
		||||
  // std::unique_ptr<Literal>.
 | 
			
		||||
  // Clones the underlying buffers into a new Literal.
 | 
			
		||||
  Literal Clone() const;
 | 
			
		||||
  std::unique_ptr<Literal> CloneToUnique() const;
 | 
			
		||||
 | 
			
		||||
  // TODO(b/67651157): The methods below which perform computation on Literals
 | 
			
		||||
  // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
 | 
			
		||||
@ -259,24 +255,23 @@ class LiteralBase {
 | 
			
		||||
  // Note: this is useful when the client wants to ensure that a value placed in
 | 
			
		||||
  // the XLA allocation tracker has a particular layout; for efficiency
 | 
			
		||||
  // purposes or avoiding unimplemented operation/layout combinations.
 | 
			
		||||
  std::unique_ptr<Literal> Relayout(const Layout& new_layout,
 | 
			
		||||
                                    const ShapeIndex& shape_index = {}) const;
 | 
			
		||||
  Literal Relayout(const Layout& new_layout,
 | 
			
		||||
                   const ShapeIndex& shape_index = {}) const;
 | 
			
		||||
 | 
			
		||||
  // An overload of Relayout which changes the layout of the entire shape rather
 | 
			
		||||
  // than being limited to a single array within the shape.
 | 
			
		||||
  std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
 | 
			
		||||
  Literal Relayout(const Shape& shape_with_layout) const;
 | 
			
		||||
 | 
			
		||||
  // Creates a new literal by reshaping this literal to have the given
 | 
			
		||||
  // dimensions. The total number of elements must not change; The
 | 
			
		||||
  // implementation currently only supports monotonic dim0-major layouts.
 | 
			
		||||
  // This literal must be an array.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> Reshape(
 | 
			
		||||
      absl::Span<const int64> dimensions) const;
 | 
			
		||||
  StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
 | 
			
		||||
 | 
			
		||||
  // Creates a new literal by broadcasting this literal with `dimensions` to
 | 
			
		||||
  // yield a literal of shape `result_shape`.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> Broadcast(
 | 
			
		||||
      const Shape& result_shape, absl::Span<const int64> dimensions) const;
 | 
			
		||||
  StatusOr<Literal> Broadcast(const Shape& result_shape,
 | 
			
		||||
                              absl::Span<const int64> dimensions) const;
 | 
			
		||||
 | 
			
		||||
  // Creates a new literal by reordering the dimensions of this literal.
 | 
			
		||||
  // The given `permutation` must be a permutation of the dimension numbers
 | 
			
		||||
@ -285,7 +280,7 @@ class LiteralBase {
 | 
			
		||||
  // For example, a transpose call on a literal of shape [3 x 8 x 4] and
 | 
			
		||||
  // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
 | 
			
		||||
  // This literal must be an array.
 | 
			
		||||
  std::unique_ptr<Literal> Transpose(absl::Span<const int64> permutation) const;
 | 
			
		||||
  Literal Transpose(absl::Span<const int64> permutation) const;
 | 
			
		||||
 | 
			
		||||
  // Creates a sub-array from this literal by extracting the indices
 | 
			
		||||
  // [start_index, limit_index) of each dimension. The result literal has the
 | 
			
		||||
@ -293,15 +288,15 @@ class LiteralBase {
 | 
			
		||||
  // start_indices and limit_indices must be the rank of the literal, and the
 | 
			
		||||
  // indices follow the order of the dimensions.
 | 
			
		||||
  // This literal must be an array.
 | 
			
		||||
  std::unique_ptr<Literal> Slice(absl::Span<const int64> start_indices,
 | 
			
		||||
                                 absl::Span<const int64> limit_indices) const;
 | 
			
		||||
  Literal Slice(absl::Span<const int64> start_indices,
 | 
			
		||||
                absl::Span<const int64> limit_indices) const;
 | 
			
		||||
 | 
			
		||||
  // Creates a literal with a prepended dimension with bound "times"; e.g. a
 | 
			
		||||
  // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
 | 
			
		||||
  // literal replicated four times.
 | 
			
		||||
  // This literal must be an array.
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  std::unique_ptr<Literal> Replicate(int64 times) const;
 | 
			
		||||
  Literal Replicate(int64 times) const;
 | 
			
		||||
 | 
			
		||||
  // Creates a new Literal object with the shape specified as parameter.
 | 
			
		||||
  // The content of the literal values is the default value of the primitive
 | 
			
		||||
@ -312,7 +307,7 @@ class LiteralBase {
 | 
			
		||||
  // initialization, then reinitialization. Conside if a call to
 | 
			
		||||
  // absl::make_unique<Literal>(shape), followed by the call to
 | 
			
		||||
  // MutableLiteralBase::Populate can be used instead.
 | 
			
		||||
  static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
 | 
			
		||||
  static Literal CreateFromShape(const Shape& shape);
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  // A data structure representing a subshape at a particular ShapeIndex within
 | 
			
		||||
@ -539,8 +534,8 @@ class LiteralBase {
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  std::unique_ptr<Literal> SliceInternal(
 | 
			
		||||
      const Shape& result_shape, absl::Span<const int64> start_indices) const;
 | 
			
		||||
  Literal SliceInternal(const Shape& result_shape,
 | 
			
		||||
                        absl::Span<const int64> start_indices) const;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Abstract base class representing a mutable literal in XLA.
 | 
			
		||||
@ -687,8 +682,7 @@ class MutableLiteralBase : public LiteralBase {
 | 
			
		||||
  static Literal MoveIntoTuple(absl::Span<Literal> elements);
 | 
			
		||||
 | 
			
		||||
  // Serialize from a proto.
 | 
			
		||||
  static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
 | 
			
		||||
      const LiteralProto& proto);
 | 
			
		||||
  static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  // Returns the piece at the given ShapeIndex.
 | 
			
		||||
@ -1137,15 +1131,14 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
 | 
			
		||||
Literal LiteralBase::Replicate(int64 times) const {
 | 
			
		||||
  DimensionVector bounds = {times};
 | 
			
		||||
  bounds.reserve(shape().dimensions_size() + 1);
 | 
			
		||||
  for (int64 bound : shape().dimensions()) {
 | 
			
		||||
    bounds.push_back(bound);
 | 
			
		||||
  }
 | 
			
		||||
  auto literal = absl::make_unique<Literal>(
 | 
			
		||||
      ShapeUtil::MakeShape(shape().element_type(), bounds));
 | 
			
		||||
  int64 elements = ShapeUtil::ElementsIn(literal->shape());
 | 
			
		||||
  Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
 | 
			
		||||
  int64 elements = ShapeUtil::ElementsIn(literal.shape());
 | 
			
		||||
  if (elements == 0) {
 | 
			
		||||
    return literal;
 | 
			
		||||
  }
 | 
			
		||||
@ -1157,7 +1150,7 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
 | 
			
		||||
  bool done = false;
 | 
			
		||||
  while (!done) {
 | 
			
		||||
    const auto element = Get<NativeT>(input_indices);
 | 
			
		||||
    literal->Set<NativeT>(output_indices, element);
 | 
			
		||||
    literal.Set<NativeT>(output_indices, element);
 | 
			
		||||
 | 
			
		||||
    done = true;
 | 
			
		||||
    for (int n = 0; n < output_indices.size(); ++n) {
 | 
			
		||||
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -45,7 +45,7 @@ using absl::StrCat;
 | 
			
		||||
// Return a literal with all arrays of type FromNativeT converted to type
 | 
			
		||||
// ToNativeT in the given literal.
 | 
			
		||||
template <typename FromNativeT, typename ToNativeT>
 | 
			
		||||
std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
 | 
			
		||||
Literal ConvertType(LiteralSlice literal) {
 | 
			
		||||
  // First construct shape of the result.
 | 
			
		||||
  Shape result_shape(literal.shape());
 | 
			
		||||
  ShapeUtil::ForEachMutableSubshape(
 | 
			
		||||
@ -56,7 +56,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
 | 
			
		||||
              primitive_util::NativeToPrimitiveType<ToNativeT>());
 | 
			
		||||
        }
 | 
			
		||||
      });
 | 
			
		||||
  auto result = absl::make_unique<Literal>(result_shape);
 | 
			
		||||
  Literal result(result_shape);
 | 
			
		||||
 | 
			
		||||
  // Then copy over the data from 'literal' converting FromNativeT values to
 | 
			
		||||
  // ToNativeT values as necessary.
 | 
			
		||||
@ -67,14 +67,14 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
 | 
			
		||||
          if (subshape.element_type() ==
 | 
			
		||||
              primitive_util::NativeToPrimitiveType<FromNativeT>()) {
 | 
			
		||||
            auto src = literal.data<FromNativeT>(shape_index);
 | 
			
		||||
            auto dest = result->data<ToNativeT>(shape_index);
 | 
			
		||||
            auto dest = result.data<ToNativeT>(shape_index);
 | 
			
		||||
            for (int64 i = 0; i < src.size(); ++i) {
 | 
			
		||||
              dest[i] = static_cast<ToNativeT>(src[i]);
 | 
			
		||||
            }
 | 
			
		||||
          } else {
 | 
			
		||||
            TF_CHECK_OK(result->CopyFrom(literal,
 | 
			
		||||
                                         /*dest_shape_index=*/shape_index,
 | 
			
		||||
                                         /*src_shape_index=*/shape_index));
 | 
			
		||||
            TF_CHECK_OK(result.CopyFrom(literal,
 | 
			
		||||
                                        /*dest_shape_index=*/shape_index,
 | 
			
		||||
                                        /*src_shape_index=*/shape_index));
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      });
 | 
			
		||||
@ -83,53 +83,52 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateFromDimensions(
 | 
			
		||||
    PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
 | 
			
		||||
  return Literal::CreateFromShape(
 | 
			
		||||
      ShapeUtil::MakeShape(primitive_type, dimensions));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertBF16ToF32(
 | 
			
		||||
/* static */ Literal LiteralUtil::ConvertBF16ToF32(
 | 
			
		||||
    const LiteralSlice& bf16_literal) {
 | 
			
		||||
  return ConvertType<bfloat16, float>(bf16_literal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertF32ToBF16(
 | 
			
		||||
/* static */ Literal LiteralUtil::ConvertF32ToBF16(
 | 
			
		||||
    const LiteralSlice& f32_literal) {
 | 
			
		||||
  return ConvertType<float, bfloat16>(f32_literal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateToken() {
 | 
			
		||||
  return absl::make_unique<Literal>(ShapeUtil::MakeTokenShape());
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateToken() {
 | 
			
		||||
  return Literal(ShapeUtil::MakeTokenShape());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
 | 
			
		||||
  switch (primitive_type) {
 | 
			
		||||
    case U8:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<uint8>(0));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint8>(0);
 | 
			
		||||
    case U32:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<uint32>(0));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint32>(0);
 | 
			
		||||
    case U64:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<uint64>(0));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint64>(0);
 | 
			
		||||
    case S8:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<int8>(0));
 | 
			
		||||
      return LiteralUtil::CreateR0<int8>(0);
 | 
			
		||||
    case S32:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<int32>(0));
 | 
			
		||||
      return LiteralUtil::CreateR0<int32>(0);
 | 
			
		||||
    case S64:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<int64>(0));
 | 
			
		||||
      return LiteralUtil::CreateR0<int64>(0);
 | 
			
		||||
    case F16:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(0.0f)));
 | 
			
		||||
      return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
 | 
			
		||||
    case BF16:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
 | 
			
		||||
      return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
 | 
			
		||||
    case F32:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<float>(0));
 | 
			
		||||
      return LiteralUtil::CreateR0<float>(0);
 | 
			
		||||
    case F64:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<double>(0));
 | 
			
		||||
      return LiteralUtil::CreateR0<double>(0);
 | 
			
		||||
    case C64:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<complex64>(0));
 | 
			
		||||
      return LiteralUtil::CreateR0<complex64>(0);
 | 
			
		||||
    case PRED:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<bool>(false));
 | 
			
		||||
      return LiteralUtil::CreateR0<bool>(false);
 | 
			
		||||
    case S16:
 | 
			
		||||
    case U16:
 | 
			
		||||
      LOG(FATAL) << "u16/s16 literals not yet implemented";
 | 
			
		||||
@ -145,30 +144,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
 | 
			
		||||
/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
 | 
			
		||||
  switch (primitive_type) {
 | 
			
		||||
    case U8:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<uint8>(1));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint8>(1);
 | 
			
		||||
    case U32:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<uint32>(1));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint32>(1);
 | 
			
		||||
    case U64:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<uint64>(1));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint64>(1);
 | 
			
		||||
    case S8:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<int8>(1));
 | 
			
		||||
      return LiteralUtil::CreateR0<int8>(1);
 | 
			
		||||
    case S32:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<int32>(1));
 | 
			
		||||
      return LiteralUtil::CreateR0<int32>(1);
 | 
			
		||||
    case S64:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<int64>(1));
 | 
			
		||||
      return LiteralUtil::CreateR0<int64>(1);
 | 
			
		||||
    case F16:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(1.0f)));
 | 
			
		||||
      return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
 | 
			
		||||
    case BF16:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
 | 
			
		||||
      return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
 | 
			
		||||
    case F32:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<float>(1));
 | 
			
		||||
      return LiteralUtil::CreateR0<float>(1);
 | 
			
		||||
    case F64:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<double>(1));
 | 
			
		||||
      return LiteralUtil::CreateR0<double>(1);
 | 
			
		||||
    case C64:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<complex64>(1));
 | 
			
		||||
      return LiteralUtil::CreateR0<complex64>(1);
 | 
			
		||||
    case PRED:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<bool>(true));
 | 
			
		||||
      return LiteralUtil::CreateR0<bool>(true);
 | 
			
		||||
    case S16:
 | 
			
		||||
    case U16:
 | 
			
		||||
      LOG(FATAL) << "u16/s16 literals not yet implemented";
 | 
			
		||||
@ -184,42 +182,36 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
 | 
			
		||||
/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
 | 
			
		||||
  switch (primitive_type) {
 | 
			
		||||
    case U8:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
 | 
			
		||||
    case U32:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
 | 
			
		||||
    case U64:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
 | 
			
		||||
    case S8:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min()));
 | 
			
		||||
      return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
 | 
			
		||||
    case S32:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min()));
 | 
			
		||||
      return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
 | 
			
		||||
    case S64:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min()));
 | 
			
		||||
      return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
 | 
			
		||||
    case F32:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<float>(
 | 
			
		||||
          -std::numeric_limits<float>::infinity()));
 | 
			
		||||
      return LiteralUtil::CreateR0<float>(
 | 
			
		||||
          -std::numeric_limits<float>::infinity());
 | 
			
		||||
    case F64:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<double>(
 | 
			
		||||
          -std::numeric_limits<double>::infinity()));
 | 
			
		||||
      return LiteralUtil::CreateR0<double>(
 | 
			
		||||
          -std::numeric_limits<double>::infinity());
 | 
			
		||||
    case C64:
 | 
			
		||||
      LOG(FATAL) << "C64 element type has no minimum value";
 | 
			
		||||
    case PRED:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<bool>(false));
 | 
			
		||||
      return LiteralUtil::CreateR0<bool>(false);
 | 
			
		||||
    case S16:
 | 
			
		||||
    case U16:
 | 
			
		||||
      LOG(FATAL) << "u16/s16 literals not yet implemented";
 | 
			
		||||
    case F16:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<half>(
 | 
			
		||||
          static_cast<half>(-std::numeric_limits<float>::infinity())));
 | 
			
		||||
      return LiteralUtil::CreateR0<half>(
 | 
			
		||||
          static_cast<half>(-std::numeric_limits<float>::infinity()));
 | 
			
		||||
    case BF16:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<bfloat16>(
 | 
			
		||||
          static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
 | 
			
		||||
      return LiteralUtil::CreateR0<bfloat16>(
 | 
			
		||||
          static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
 | 
			
		||||
    case TUPLE:
 | 
			
		||||
      LOG(FATAL) << "tuple element type has no minimum value";
 | 
			
		||||
    case OPAQUE:
 | 
			
		||||
@ -232,40 +224,34 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
 | 
			
		||||
/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
 | 
			
		||||
  switch (primitive_type) {
 | 
			
		||||
    case U8:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
 | 
			
		||||
    case U32:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
 | 
			
		||||
    case U64:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
 | 
			
		||||
    case S8:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max()));
 | 
			
		||||
      return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
 | 
			
		||||
    case S32:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max()));
 | 
			
		||||
      return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
 | 
			
		||||
    case S64:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max()));
 | 
			
		||||
      return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
 | 
			
		||||
    case F32:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<float>(
 | 
			
		||||
          std::numeric_limits<float>::infinity()));
 | 
			
		||||
      return LiteralUtil::CreateR0<float>(
 | 
			
		||||
          std::numeric_limits<float>::infinity());
 | 
			
		||||
    case F64:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<double>(
 | 
			
		||||
          std::numeric_limits<double>::infinity()));
 | 
			
		||||
      return LiteralUtil::CreateR0<double>(
 | 
			
		||||
          std::numeric_limits<double>::infinity());
 | 
			
		||||
    case PRED:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<bool>(true));
 | 
			
		||||
      return LiteralUtil::CreateR0<bool>(true);
 | 
			
		||||
    case S16:
 | 
			
		||||
    case U16:
 | 
			
		||||
      LOG(FATAL) << "u16/s16 literals not yet implemented";
 | 
			
		||||
    case F16:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<half>(
 | 
			
		||||
          static_cast<half>(std::numeric_limits<float>::infinity())));
 | 
			
		||||
      return LiteralUtil::CreateR0<half>(
 | 
			
		||||
          static_cast<half>(std::numeric_limits<float>::infinity()));
 | 
			
		||||
    case BF16:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<bfloat16>(
 | 
			
		||||
          static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
 | 
			
		||||
      return LiteralUtil::CreateR0<bfloat16>(
 | 
			
		||||
          static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
 | 
			
		||||
    case TUPLE:
 | 
			
		||||
      LOG(FATAL) << "tuple element type has no maximum value";
 | 
			
		||||
    case OPAQUE:
 | 
			
		||||
@ -275,31 +261,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR1(
 | 
			
		||||
    const tensorflow::core::Bitmap& values) {
 | 
			
		||||
  auto literal = absl::make_unique<Literal>(
 | 
			
		||||
  Literal literal(
 | 
			
		||||
      ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
 | 
			
		||||
  literal->PopulateR1(values);
 | 
			
		||||
  literal.PopulateR1(values);
 | 
			
		||||
  return literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
 | 
			
		||||
    absl::string_view value) {
 | 
			
		||||
  auto literal = absl::make_unique<Literal>(
 | 
			
		||||
      ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
 | 
			
		||||
  Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
 | 
			
		||||
  for (int i = 0; i < value.size(); ++i) {
 | 
			
		||||
    literal->Set<uint8>({i}, value[i]);
 | 
			
		||||
    literal.Set<uint8>({i}, value[i]);
 | 
			
		||||
  }
 | 
			
		||||
  return literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2F32Linspace(
 | 
			
		||||
    float from, float to, int64 rows, int64 cols) {
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
 | 
			
		||||
                                                      int64 rows, int64 cols) {
 | 
			
		||||
  auto value = MakeLinspaceArray2D(from, to, rows, cols);
 | 
			
		||||
  return CreateR2FromArray2D(*value);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::ReshapeSlice(
 | 
			
		||||
/* static */ Literal LiteralUtil::ReshapeSlice(
 | 
			
		||||
    absl::Span<const int64> new_dimensions,
 | 
			
		||||
    absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
 | 
			
		||||
  int64 new_num_elements = 1;
 | 
			
		||||
@ -309,13 +293,13 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
 | 
			
		||||
  CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
 | 
			
		||||
  CHECK_EQ(new_dimensions.size(), minor_to_major.size());
 | 
			
		||||
 | 
			
		||||
  auto new_literal = absl::make_unique<Literal>(
 | 
			
		||||
  Literal new_literal(
 | 
			
		||||
      ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
 | 
			
		||||
 | 
			
		||||
  // Create a new shape with the given minor-to-major layout. This shape is used
 | 
			
		||||
  // solely for converting linear address to multi-dimensional addresses when
 | 
			
		||||
  // writing elements to the new literal.
 | 
			
		||||
  Shape shape_with_layout = new_literal->shape();
 | 
			
		||||
  Shape shape_with_layout = new_literal.shape();
 | 
			
		||||
  *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
 | 
			
		||||
 | 
			
		||||
  // Copy data into new literal, element-by-element.
 | 
			
		||||
@ -326,40 +310,40 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
 | 
			
		||||
        IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
 | 
			
		||||
    switch (literal.shape().element_type()) {
 | 
			
		||||
      case PRED:
 | 
			
		||||
        new_literal->Set<bool>(to_multi_index,
 | 
			
		||||
                               literal.Get<bool>(from_multi_index));
 | 
			
		||||
        new_literal.Set<bool>(to_multi_index,
 | 
			
		||||
                              literal.Get<bool>(from_multi_index));
 | 
			
		||||
        break;
 | 
			
		||||
      case U8:
 | 
			
		||||
        new_literal->Set<uint8>(to_multi_index,
 | 
			
		||||
                                literal.Get<uint8>(from_multi_index));
 | 
			
		||||
        new_literal.Set<uint8>(to_multi_index,
 | 
			
		||||
                               literal.Get<uint8>(from_multi_index));
 | 
			
		||||
        break;
 | 
			
		||||
      case U32:
 | 
			
		||||
        new_literal->Set<uint32>(to_multi_index,
 | 
			
		||||
                                 literal.Get<uint32>(from_multi_index));
 | 
			
		||||
        new_literal.Set<uint32>(to_multi_index,
 | 
			
		||||
                                literal.Get<uint32>(from_multi_index));
 | 
			
		||||
        break;
 | 
			
		||||
      case S32:
 | 
			
		||||
        new_literal->Set<int32>(to_multi_index,
 | 
			
		||||
                                literal.Get<int32>(from_multi_index));
 | 
			
		||||
        new_literal.Set<int32>(to_multi_index,
 | 
			
		||||
                               literal.Get<int32>(from_multi_index));
 | 
			
		||||
        break;
 | 
			
		||||
      case U64:
 | 
			
		||||
        new_literal->Set<uint64>(to_multi_index,
 | 
			
		||||
                                 literal.Get<uint64>(from_multi_index));
 | 
			
		||||
        new_literal.Set<uint64>(to_multi_index,
 | 
			
		||||
                                literal.Get<uint64>(from_multi_index));
 | 
			
		||||
        break;
 | 
			
		||||
      case S64:
 | 
			
		||||
        new_literal->Set<int64>(to_multi_index,
 | 
			
		||||
                                literal.Get<int64>(from_multi_index));
 | 
			
		||||
        new_literal.Set<int64>(to_multi_index,
 | 
			
		||||
                               literal.Get<int64>(from_multi_index));
 | 
			
		||||
        break;
 | 
			
		||||
      case F32:
 | 
			
		||||
        new_literal->Set<float>(to_multi_index,
 | 
			
		||||
                                literal.Get<float>(from_multi_index));
 | 
			
		||||
        new_literal.Set<float>(to_multi_index,
 | 
			
		||||
                               literal.Get<float>(from_multi_index));
 | 
			
		||||
        break;
 | 
			
		||||
      case F64:
 | 
			
		||||
        new_literal->Set<double>(to_multi_index,
 | 
			
		||||
                                 literal.Get<double>(from_multi_index));
 | 
			
		||||
        new_literal.Set<double>(to_multi_index,
 | 
			
		||||
                                literal.Get<double>(from_multi_index));
 | 
			
		||||
        break;
 | 
			
		||||
      case C64:
 | 
			
		||||
        new_literal->Set<complex64>(to_multi_index,
 | 
			
		||||
                                    literal.Get<complex64>(from_multi_index));
 | 
			
		||||
        new_literal.Set<complex64>(to_multi_index,
 | 
			
		||||
                                   literal.Get<complex64>(from_multi_index));
 | 
			
		||||
        break;
 | 
			
		||||
      default:
 | 
			
		||||
        LOG(FATAL) << "Unhandled primitive element type: "
 | 
			
		||||
@ -376,97 +360,82 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
 | 
			
		||||
  CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
 | 
			
		||||
  switch (literal.shape().element_type()) {
 | 
			
		||||
    case PRED:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
 | 
			
		||||
    // 8 bit types.
 | 
			
		||||
    case S8:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
 | 
			
		||||
    case U8:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
 | 
			
		||||
    // 16 bit types.
 | 
			
		||||
    case BF16:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<bfloat16>(
 | 
			
		||||
          literal.GetFirstElement<bfloat16>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<bfloat16>(
 | 
			
		||||
          literal.GetFirstElement<bfloat16>());
 | 
			
		||||
    case F16:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
 | 
			
		||||
    case S16:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
 | 
			
		||||
    case U16:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
 | 
			
		||||
    // 32 bit types.
 | 
			
		||||
    case F32:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
 | 
			
		||||
    case S32:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
 | 
			
		||||
    case U32:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
 | 
			
		||||
    // 64 bit types.
 | 
			
		||||
    case C64:
 | 
			
		||||
      return std::move(*LiteralUtil::CreateR0<complex64>(
 | 
			
		||||
          literal.GetFirstElement<complex64>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<complex64>(
 | 
			
		||||
          literal.GetFirstElement<complex64>());
 | 
			
		||||
    case F64:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
 | 
			
		||||
    case S64:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
 | 
			
		||||
    case U64:
 | 
			
		||||
      return std::move(
 | 
			
		||||
          *LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>()));
 | 
			
		||||
      return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
 | 
			
		||||
    default:
 | 
			
		||||
      LOG(FATAL) << "Unhandled primitive type "
 | 
			
		||||
                 << literal.shape().element_type();
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
 | 
			
		||||
/* static */ Literal LiteralUtil::MakeTuple(
 | 
			
		||||
    absl::Span<const Literal* const> elements) {
 | 
			
		||||
  std::vector<Shape> element_shapes;
 | 
			
		||||
  for (const auto* element : elements) {
 | 
			
		||||
    element_shapes.push_back(element->shape());
 | 
			
		||||
  }
 | 
			
		||||
  auto literal =
 | 
			
		||||
      absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
 | 
			
		||||
  Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
 | 
			
		||||
  for (int i = 0; i < elements.size(); ++i) {
 | 
			
		||||
    TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
 | 
			
		||||
    TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
 | 
			
		||||
  }
 | 
			
		||||
  return literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
/* static */ Literal LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
    absl::Span<const LiteralSlice> elements) {
 | 
			
		||||
  std::vector<Shape> element_shapes;
 | 
			
		||||
  for (const auto& element : elements) {
 | 
			
		||||
    element_shapes.push_back(element.shape());
 | 
			
		||||
  }
 | 
			
		||||
  auto literal =
 | 
			
		||||
      absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
 | 
			
		||||
  Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
 | 
			
		||||
  for (int i = 0; i < elements.size(); ++i) {
 | 
			
		||||
    TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
 | 
			
		||||
    TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
 | 
			
		||||
  }
 | 
			
		||||
  return literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleOwned(
 | 
			
		||||
    std::vector<std::unique_ptr<Literal>> elements) {
 | 
			
		||||
/* static */ Literal LiteralUtil::MakeTupleOwned(
 | 
			
		||||
    std::vector<Literal> elements) {
 | 
			
		||||
  std::vector<Shape> element_shapes;
 | 
			
		||||
  element_shapes.reserve(elements.size());
 | 
			
		||||
  for (const auto& element : elements) {
 | 
			
		||||
    element_shapes.push_back(element->shape());
 | 
			
		||||
    element_shapes.push_back(element.shape());
 | 
			
		||||
  }
 | 
			
		||||
  auto literal =
 | 
			
		||||
      absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
 | 
			
		||||
  Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
 | 
			
		||||
  for (int64 i = 0; i < elements.size(); ++i) {
 | 
			
		||||
    TF_CHECK_OK(
 | 
			
		||||
        literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i}));
 | 
			
		||||
        literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
 | 
			
		||||
  }
 | 
			
		||||
  return literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -69,36 +69,34 @@ class LiteralUtil {
 | 
			
		||||
  // The variants not ending with WithLayout use the default XLA layout for the
 | 
			
		||||
  // literal's linear representation in memory.
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR0(NativeT value);
 | 
			
		||||
  static Literal CreateR0(NativeT value);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR1(absl::Span<const NativeT> values);
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR1(
 | 
			
		||||
      const tensorflow::core::Bitmap& values);
 | 
			
		||||
  static Literal CreateR1(absl::Span<const NativeT> values);
 | 
			
		||||
  static Literal CreateR1(const tensorflow::core::Bitmap& values);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR2(
 | 
			
		||||
  static Literal CreateR2(
 | 
			
		||||
      std::initializer_list<std::initializer_list<NativeT>> values);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR2WithLayout(
 | 
			
		||||
  static Literal CreateR2WithLayout(
 | 
			
		||||
      std::initializer_list<std::initializer_list<NativeT>> values,
 | 
			
		||||
      const Layout& layout);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR3(
 | 
			
		||||
      std::initializer_list<
 | 
			
		||||
          std::initializer_list<std::initializer_list<NativeT>>>
 | 
			
		||||
          values);
 | 
			
		||||
  static Literal CreateR3(std::initializer_list<
 | 
			
		||||
                          std::initializer_list<std::initializer_list<NativeT>>>
 | 
			
		||||
                              values);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR3WithLayout(
 | 
			
		||||
  static Literal CreateR3WithLayout(
 | 
			
		||||
      std::initializer_list<
 | 
			
		||||
          std::initializer_list<std::initializer_list<NativeT>>>
 | 
			
		||||
          values,
 | 
			
		||||
      const Layout& layout);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR4(
 | 
			
		||||
  static Literal CreateR4(
 | 
			
		||||
      std::initializer_list<std::initializer_list<
 | 
			
		||||
          std::initializer_list<std::initializer_list<NativeT>>>>
 | 
			
		||||
          values);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR4WithLayout(
 | 
			
		||||
  static Literal CreateR4WithLayout(
 | 
			
		||||
      std::initializer_list<std::initializer_list<
 | 
			
		||||
          std::initializer_list<std::initializer_list<NativeT>>>>
 | 
			
		||||
          values,
 | 
			
		||||
@ -139,9 +137,10 @@ class LiteralUtil {
 | 
			
		||||
  //     [9, 10, 11]: 4.0
 | 
			
		||||
  //
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateSparse(
 | 
			
		||||
      absl::Span<const int64> dimensions, SparseIndexArray indices,
 | 
			
		||||
      absl::Span<const NativeT> values, bool sort = true);
 | 
			
		||||
  static Literal CreateSparse(absl::Span<const int64> dimensions,
 | 
			
		||||
                              SparseIndexArray indices,
 | 
			
		||||
                              absl::Span<const NativeT> values,
 | 
			
		||||
                              bool sort = true);
 | 
			
		||||
 | 
			
		||||
  // Creates a scalar literal value zero of the given primitive type.
 | 
			
		||||
  static Literal Zero(PrimitiveType primitive_type);
 | 
			
		||||
@ -155,130 +154,120 @@ class LiteralUtil {
 | 
			
		||||
  static Literal MaxValue(PrimitiveType primitive_type);
 | 
			
		||||
  // Creates a literal of the given shape where each element is `value`.
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
 | 
			
		||||
  static Literal CreateFullWithDescendingLayout(
 | 
			
		||||
      absl::Span<const int64> dimensions, NativeT value);
 | 
			
		||||
 | 
			
		||||
  // Creates a new literal from an Array type. The variants not ending with
 | 
			
		||||
  // WithLayout use the default XLA layout for the literal's linear
 | 
			
		||||
  // representation in memory.
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
 | 
			
		||||
  static Literal CreateFromArray(const Array<NativeT>& values);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateFromArrayWithLayout(
 | 
			
		||||
      const Array<NativeT>& values, const Layout& layout);
 | 
			
		||||
  static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
 | 
			
		||||
                                           const Layout& layout);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR2FromArray2D(
 | 
			
		||||
      const Array2D<NativeT>& values);
 | 
			
		||||
  static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
 | 
			
		||||
      const Array2D<NativeT>& values, const Layout& layout);
 | 
			
		||||
  static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
 | 
			
		||||
                                               const Layout& layout);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR3FromArray3D(
 | 
			
		||||
      const Array3D<NativeT>& values);
 | 
			
		||||
  static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
 | 
			
		||||
      const Array3D<NativeT>& values, const Layout& layout);
 | 
			
		||||
  static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
 | 
			
		||||
                                               const Layout& layout);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR4FromArray4D(
 | 
			
		||||
      const Array4D<NativeT>& values);
 | 
			
		||||
  static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
 | 
			
		||||
      const Array4D<NativeT>& values, const Layout& layout);
 | 
			
		||||
  static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
 | 
			
		||||
                                               const Layout& layout);
 | 
			
		||||
 | 
			
		||||
  // Creates a new vector of U8s literal value from a string.
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR1U8(absl::string_view value);
 | 
			
		||||
  static Literal CreateR1U8(absl::string_view value);
 | 
			
		||||
 | 
			
		||||
  // Creates a linspace-populated literal with the given number of rows and
 | 
			
		||||
  // columns.
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
 | 
			
		||||
                                                      int64 rows, int64 cols);
 | 
			
		||||
  static Literal CreateR2F32Linspace(float from, float to, int64 rows,
 | 
			
		||||
                                     int64 cols);
 | 
			
		||||
 | 
			
		||||
  // Creates a literal that projects the (x, y) dimensions given in values into
 | 
			
		||||
  // the z dimension given by "projection".
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR3Projected(
 | 
			
		||||
  static Literal CreateR3Projected(
 | 
			
		||||
      std::initializer_list<std::initializer_list<NativeT>> values,
 | 
			
		||||
      int64 projection);
 | 
			
		||||
 | 
			
		||||
  // Creates a literal that projects the (x, y) dimensions given in values into
 | 
			
		||||
  // the z and p dimensions given.
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> CreateR4Projected(
 | 
			
		||||
  static Literal CreateR4Projected(
 | 
			
		||||
      std::initializer_list<std::initializer_list<NativeT>> values,
 | 
			
		||||
      int64 projection_p, int64 projection_z);
 | 
			
		||||
 | 
			
		||||
  // Returns an identity matrix (rank 2) with the given row and column count.
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
 | 
			
		||||
  static Literal MakeIdentityR2(int64 size);
 | 
			
		||||
 | 
			
		||||
  // Returns a tuple literal composed of given literals. Data is copied from the
 | 
			
		||||
  // given elements into the returned literal.
 | 
			
		||||
  static std::unique_ptr<Literal> MakeTuple(
 | 
			
		||||
      absl::Span<const Literal* const> elements);
 | 
			
		||||
  static Literal MakeTuple(absl::Span<const Literal* const> elements);
 | 
			
		||||
 | 
			
		||||
  static std::unique_ptr<Literal> MakeTupleFromSlices(
 | 
			
		||||
      absl::Span<const LiteralSlice> elements);
 | 
			
		||||
  static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
 | 
			
		||||
 | 
			
		||||
  // As above, but intended to be invoked with move semantics; i.e.
 | 
			
		||||
  //
 | 
			
		||||
  //  std::vector<std::unique_ptr<Literal>> elements = ...;
 | 
			
		||||
  //  std::vector<Literal> elements = ...;
 | 
			
		||||
  //  auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
 | 
			
		||||
  //
 | 
			
		||||
  // This would have been declared as an overload, but there is ambiguity
 | 
			
		||||
  // in invocation between the above signature and this one.
 | 
			
		||||
  static std::unique_ptr<Literal> MakeTupleOwned(
 | 
			
		||||
      std::vector<std::unique_ptr<Literal>> elements);
 | 
			
		||||
  static Literal MakeTupleOwned(std::vector<Literal> elements);
 | 
			
		||||
 | 
			
		||||
  // This overload lets you pass a braced list of unique_ptr<Literal>s to
 | 
			
		||||
  // This overload lets you pass a braced list of Literals to
 | 
			
		||||
  // MakeTupleOwned:
 | 
			
		||||
  //
 | 
			
		||||
  //   LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
 | 
			
		||||
  //
 | 
			
		||||
  // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
 | 
			
		||||
  // Simply relying on the MakeTupleOwned(std::vector<Literal>)
 | 
			
		||||
  // overload doesn't work because std::initializer_list's elements are always
 | 
			
		||||
  // const.
 | 
			
		||||
  //
 | 
			
		||||
  // The arguments to this function must all be unique_ptr<Literal>.
 | 
			
		||||
  // The arguments to this function must all be Literal.
 | 
			
		||||
  template <typename... Ts>
 | 
			
		||||
  static std::unique_ptr<Literal> MakeTupleOwned(
 | 
			
		||||
      std::unique_ptr<Ts>... elements) {
 | 
			
		||||
    std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{
 | 
			
		||||
        std::move(elements)...};
 | 
			
		||||
    std::vector<std::unique_ptr<Literal>> v;
 | 
			
		||||
  static Literal MakeTupleOwned(Ts... elements) {
 | 
			
		||||
    std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
 | 
			
		||||
    std::vector<Literal> v;
 | 
			
		||||
    v.insert(v.begin(), std::make_move_iterator(arr.begin()),
 | 
			
		||||
             std::make_move_iterator(arr.end()));
 | 
			
		||||
    return MakeTupleOwned(std::move(v));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Create a constant token literal. Token types have no value.
 | 
			
		||||
  static std::unique_ptr<Literal> CreateToken();
 | 
			
		||||
  static Literal CreateToken();
 | 
			
		||||
 | 
			
		||||
  // Creates a new Literal object with its values havings the primitive_type
 | 
			
		||||
  // type, and with dimensions defined by the dimensions parameter.
 | 
			
		||||
  // The content of the literal values is the default value of the primitive
 | 
			
		||||
  // type of literal itself (0 for numeric types, and false for predicates).
 | 
			
		||||
  static std::unique_ptr<Literal> CreateFromDimensions(
 | 
			
		||||
      PrimitiveType primitive_type, absl::Span<const int64> dimensions);
 | 
			
		||||
  static Literal CreateFromDimensions(PrimitiveType primitive_type,
 | 
			
		||||
                                      absl::Span<const int64> dimensions);
 | 
			
		||||
 | 
			
		||||
  // If the given literal's data type is bfloat16, converts it to a float
 | 
			
		||||
  // literal; otherwise, returns a copy of it. If the literal is a tuple,
 | 
			
		||||
  // recursively converts its elements.
 | 
			
		||||
  static std::unique_ptr<Literal> ConvertBF16ToF32(
 | 
			
		||||
      const LiteralSlice& bf16_literal);
 | 
			
		||||
  static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
 | 
			
		||||
 | 
			
		||||
  // If the given literal's data type is float, converts it to a bfloat16
 | 
			
		||||
  // literal; otherwise, returns a copy of it. If the literal is a tuple,
 | 
			
		||||
  // recursively converts its elements.
 | 
			
		||||
  static std::unique_ptr<Literal> ConvertF32ToBF16(
 | 
			
		||||
      const LiteralSlice& f32_literal);
 | 
			
		||||
  static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
 | 
			
		||||
 | 
			
		||||
  // Creates a literal with a new shape with the given new dimensions using the
 | 
			
		||||
  // data in the given input literal. For reshaping purposes the (flat) data
 | 
			
		||||
  // buffer of the input literal is assumed to have the given minor_to_major
 | 
			
		||||
  // layout order.
 | 
			
		||||
  static std::unique_ptr<Literal> ReshapeSlice(
 | 
			
		||||
      absl::Span<const int64> new_dimensions,
 | 
			
		||||
      absl::Span<const int64> minor_to_major, const LiteralSlice& literal);
 | 
			
		||||
  static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
 | 
			
		||||
                              absl::Span<const int64> minor_to_major,
 | 
			
		||||
                              const LiteralSlice& literal);
 | 
			
		||||
 | 
			
		||||
  // Creates a literal with the supplied shape, and uses the provided value
 | 
			
		||||
  // generator to populate the literal's values.
 | 
			
		||||
@ -286,7 +275,7 @@ class LiteralUtil {
 | 
			
		||||
  template <
 | 
			
		||||
      PrimitiveType type,
 | 
			
		||||
      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
 | 
			
		||||
  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
 | 
			
		||||
  static StatusOr<Literal> CreateRandomLiteral(
 | 
			
		||||
      const Shape& shape,
 | 
			
		||||
      const std::function<T(absl::Span<const int64>)>& generator);
 | 
			
		||||
 | 
			
		||||
@ -297,8 +286,8 @@ class LiteralUtil {
 | 
			
		||||
  template <
 | 
			
		||||
      PrimitiveType type, typename E,
 | 
			
		||||
      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
 | 
			
		||||
  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
 | 
			
		||||
      const Shape& shape, E* engine, T mean, T stddev);
 | 
			
		||||
  static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
 | 
			
		||||
                                               T mean, T stddev);
 | 
			
		||||
 | 
			
		||||
  // Creates a literal with the supplied shape, and initializes the literal
 | 
			
		||||
  // values using a normal distribution with given mean and stddev standard
 | 
			
		||||
@ -307,8 +296,8 @@ class LiteralUtil {
 | 
			
		||||
  template <
 | 
			
		||||
      PrimitiveType type,
 | 
			
		||||
      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
 | 
			
		||||
  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
 | 
			
		||||
      const Shape& shape, T mean, T stddev);
 | 
			
		||||
  static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
 | 
			
		||||
                                               T stddev);
 | 
			
		||||
 | 
			
		||||
  //
 | 
			
		||||
  // End of factory methods.
 | 
			
		||||
@ -322,44 +311,43 @@ class LiteralUtil {
 | 
			
		||||
std::ostream& operator<<(std::ostream& out, const Literal& literal);
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
 | 
			
		||||
  auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShape(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR0(NativeT value) {
 | 
			
		||||
  Literal literal(ShapeUtil::MakeShape(
 | 
			
		||||
      primitive_util::NativeToPrimitiveType<NativeT>(), {}));
 | 
			
		||||
  literal->Set({}, value);
 | 
			
		||||
  literal.Set({}, value);
 | 
			
		||||
  return literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
 | 
			
		||||
    absl::Span<const NativeT> values) {
 | 
			
		||||
  auto literal = absl::make_unique<Literal>(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
 | 
			
		||||
  Literal literal(
 | 
			
		||||
      ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
 | 
			
		||||
                           {static_cast<int64>(values.size())}));
 | 
			
		||||
  literal->PopulateR1(values);
 | 
			
		||||
  literal.PopulateR1(values);
 | 
			
		||||
  return literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR2WithLayout(
 | 
			
		||||
    std::initializer_list<std::initializer_list<NativeT>> values,
 | 
			
		||||
    const Layout& layout) {
 | 
			
		||||
  auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
 | 
			
		||||
  Literal literal(ShapeUtil::MakeShapeWithLayout(
 | 
			
		||||
      primitive_util::NativeToPrimitiveType<NativeT>(),
 | 
			
		||||
      {static_cast<int64>(values.size()),
 | 
			
		||||
       static_cast<int64>(values.begin()->size())},
 | 
			
		||||
      AsInt64Slice(layout.minor_to_major())));
 | 
			
		||||
  literal->PopulateR2(values);
 | 
			
		||||
  literal.PopulateR2(values);
 | 
			
		||||
  return literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR2(
 | 
			
		||||
    std::initializer_list<std::initializer_list<NativeT>> values) {
 | 
			
		||||
  return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR3WithLayout(
 | 
			
		||||
    std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
 | 
			
		||||
        values,
 | 
			
		||||
    const Layout& layout) {
 | 
			
		||||
@ -384,14 +372,14 @@ template <typename NativeT>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR3(
 | 
			
		||||
    std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
 | 
			
		||||
        values) {
 | 
			
		||||
  return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR4WithLayout(
 | 
			
		||||
    std::initializer_list<std::initializer_list<
 | 
			
		||||
        std::initializer_list<std::initializer_list<NativeT>>>>
 | 
			
		||||
        values,
 | 
			
		||||
@ -422,23 +410,22 @@ template <typename NativeT>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateSparse(
 | 
			
		||||
    absl::Span<const int64> dimensions, SparseIndexArray indices,
 | 
			
		||||
    absl::Span<const NativeT> values, bool sort) {
 | 
			
		||||
  int64 num_elements = values.size();
 | 
			
		||||
  int64 rank = dimensions.size();
 | 
			
		||||
  CHECK_EQ(num_elements, indices.index_count());
 | 
			
		||||
  CHECK_EQ(rank, indices.rank());
 | 
			
		||||
  auto literal =
 | 
			
		||||
      absl::make_unique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
 | 
			
		||||
          primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
 | 
			
		||||
          indices.max_indices()));
 | 
			
		||||
  literal->PopulateSparse(indices, values, sort);
 | 
			
		||||
  Literal literal(ShapeUtil::MakeShapeWithSparseLayout(
 | 
			
		||||
      primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
 | 
			
		||||
      indices.max_indices()));
 | 
			
		||||
  literal.PopulateSparse(indices, values, sort);
 | 
			
		||||
  return literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR4(
 | 
			
		||||
    std::initializer_list<std::initializer_list<
 | 
			
		||||
        std::initializer_list<std::initializer_list<NativeT>>>>
 | 
			
		||||
        values) {
 | 
			
		||||
@ -446,50 +433,48 @@ template <typename NativeT>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
 | 
			
		||||
    const Array<NativeT>& values, const Layout& layout) {
 | 
			
		||||
  auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
 | 
			
		||||
  Literal literal(ShapeUtil::MakeShapeWithLayout(
 | 
			
		||||
      primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
 | 
			
		||||
      AsInt64Slice(layout.minor_to_major())));
 | 
			
		||||
  literal->PopulateFromArray(values);
 | 
			
		||||
  literal.PopulateFromArray(values);
 | 
			
		||||
  return literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArray(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateFromArray(
 | 
			
		||||
    const Array<NativeT>& values) {
 | 
			
		||||
  return CreateFromArrayWithLayout(
 | 
			
		||||
      values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal>
 | 
			
		||||
LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
 | 
			
		||||
                                           const Layout& layout) {
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
 | 
			
		||||
    const Array2D<NativeT>& values, const Layout& layout) {
 | 
			
		||||
  return CreateFromArrayWithLayout(values, layout);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR2FromArray2D(
 | 
			
		||||
    const Array2D<NativeT>& values) {
 | 
			
		||||
  return CreateFromArray(values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal>
 | 
			
		||||
LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
 | 
			
		||||
                                           const Layout& layout) {
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
 | 
			
		||||
    const Array3D<NativeT>& values, const Layout& layout) {
 | 
			
		||||
  return CreateFromArrayWithLayout(values, layout);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR3FromArray3D(
 | 
			
		||||
    const Array3D<NativeT>& values) {
 | 
			
		||||
  return CreateFromArray(values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR3Projected(
 | 
			
		||||
    std::initializer_list<std::initializer_list<NativeT>> values,
 | 
			
		||||
    int64 projection) {
 | 
			
		||||
  int64 dim0_size = projection;
 | 
			
		||||
@ -514,7 +499,7 @@ template <typename NativeT>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR4Projected(
 | 
			
		||||
    std::initializer_list<std::initializer_list<NativeT>> values,
 | 
			
		||||
    int64 projection_p, int64 projection_z) {
 | 
			
		||||
  int64 dim0_size = projection_p;
 | 
			
		||||
@ -542,21 +527,20 @@ template <typename NativeT>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D(
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR4FromArray4D(
 | 
			
		||||
    const Array4D<NativeT>& values) {
 | 
			
		||||
  return CreateFromArray(values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal>
 | 
			
		||||
LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
 | 
			
		||||
                                           const Layout& layout) {
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
 | 
			
		||||
    const Array4D<NativeT>& values, const Layout& layout) {
 | 
			
		||||
  return CreateFromArrayWithLayout(values, layout);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns an identity matrix (rank 2) with the given row and column count.
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) {
 | 
			
		||||
/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
 | 
			
		||||
  Array2D<NativeT> array(size, size, 0);
 | 
			
		||||
  for (int64 i = 0; i < size; ++i) {
 | 
			
		||||
    array(i, i) = 1;
 | 
			
		||||
@ -565,33 +549,29 @@ template <typename NativeT>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename NativeT>
 | 
			
		||||
/* static */ std::unique_ptr<Literal>
 | 
			
		||||
LiteralUtil::CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,
 | 
			
		||||
                                            NativeT value) {
 | 
			
		||||
  auto literal =
 | 
			
		||||
      absl::make_unique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
 | 
			
		||||
          primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
 | 
			
		||||
  literal->PopulateWithValue(value);
 | 
			
		||||
/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
 | 
			
		||||
    absl::Span<const int64> dimensions, NativeT value) {
 | 
			
		||||
  Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
 | 
			
		||||
      primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
 | 
			
		||||
  literal.PopulateWithValue(value);
 | 
			
		||||
  return literal;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <PrimitiveType type, typename T>
 | 
			
		||||
/* static */ StatusOr<std::unique_ptr<Literal>>
 | 
			
		||||
LiteralUtil::CreateRandomLiteral(
 | 
			
		||||
/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
 | 
			
		||||
    const Shape& shape,
 | 
			
		||||
    const std::function<T(absl::Span<const int64>)>& generator) {
 | 
			
		||||
  using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
 | 
			
		||||
  TF_RET_CHECK(shape.element_type() == type);
 | 
			
		||||
  auto literal = absl::make_unique<Literal>(shape);
 | 
			
		||||
  TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
 | 
			
		||||
  Literal literal(shape);
 | 
			
		||||
  TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
 | 
			
		||||
      [&](absl::Span<const int64> indexes) { return generator(indexes); }));
 | 
			
		||||
  return std::move(literal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <PrimitiveType type, typename E, typename T>
 | 
			
		||||
/* static */ StatusOr<std::unique_ptr<Literal>>
 | 
			
		||||
LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
 | 
			
		||||
                                 T stddev) {
 | 
			
		||||
/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
 | 
			
		||||
    const Shape& shape, E* engine, T mean, T stddev) {
 | 
			
		||||
  using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
 | 
			
		||||
  std::normal_distribution<NativeT> generator(mean, stddev);
 | 
			
		||||
  return CreateRandomLiteral<type, NativeT>(
 | 
			
		||||
@ -600,8 +580,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <PrimitiveType type, typename T>
 | 
			
		||||
/* static */ StatusOr<std::unique_ptr<Literal>>
 | 
			
		||||
LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
 | 
			
		||||
/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
 | 
			
		||||
    const Shape& shape, T mean, T stddev) {
 | 
			
		||||
  std::minstd_rand0 engine;
 | 
			
		||||
  return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -39,8 +39,8 @@ PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file)
 | 
			
		||||
 | 
			
		||||
PackedLiteralReader::~PackedLiteralReader() { delete file_; }
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
 | 
			
		||||
    const Shape& shape, const Layout* layout) {
 | 
			
		||||
StatusOr<Literal> PackedLiteralReader::Read(const Shape& shape,
 | 
			
		||||
                                            const Layout* layout) {
 | 
			
		||||
  VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
 | 
			
		||||
          << " layout: "
 | 
			
		||||
          << (layout == nullptr ? "<none>" : layout->ShortDebugString());
 | 
			
		||||
@ -57,11 +57,11 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
 | 
			
		||||
        PrimitiveType_Name(shape.element_type()));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto result = absl::make_unique<Literal>(literal_shape);
 | 
			
		||||
  result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
 | 
			
		||||
  Literal result(literal_shape);
 | 
			
		||||
  result.PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
 | 
			
		||||
 | 
			
		||||
  int64 elements = ShapeUtil::ElementsIn(shape);
 | 
			
		||||
  absl::Span<const float> field = result->data<float>();
 | 
			
		||||
  absl::Span<const float> field = result.data<float>();
 | 
			
		||||
  char* data = absl::bit_cast<char*>(field.data());
 | 
			
		||||
  uint64 bytes = elements * sizeof(float);
 | 
			
		||||
  absl::string_view sp;
 | 
			
		||||
 | 
			
		||||
@ -41,8 +41,7 @@ class PackedLiteralReader {
 | 
			
		||||
  //
 | 
			
		||||
  // Layout is optional. If it is not provided, no layout is set on the literal
 | 
			
		||||
  // that is produced.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> Read(const Shape& shape,
 | 
			
		||||
                                          const Layout* layout = nullptr);
 | 
			
		||||
  StatusOr<Literal> Read(const Shape& shape, const Layout* layout = nullptr);
 | 
			
		||||
 | 
			
		||||
  // Returns whether the input file has been fully exhausted; i.e. all available
 | 
			
		||||
  // packed literals have been read and we're at the end of the file.
 | 
			
		||||
 | 
			
		||||
@ -81,8 +81,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal,
 | 
			
		||||
  return client->TransferToInfeedLocal(literal, device_ordinal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocalReplica(
 | 
			
		||||
    const Shape& shape, int replica_number) {
 | 
			
		||||
StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
 | 
			
		||||
                                                  int replica_number) {
 | 
			
		||||
  VLOG(1) << "Outfeeding literal from replica number: " << replica_number
 | 
			
		||||
          << " shape: " << shape;
 | 
			
		||||
  LocalClient* client = GetOrCreateLocalClient();
 | 
			
		||||
@ -141,9 +141,8 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
 | 
			
		||||
  LocalClient* client = GetOrCreateLocalClient();
 | 
			
		||||
  StatusOr<ScopedShapedBuffer> buf = [&] {
 | 
			
		||||
    if (shape_with_layout) {
 | 
			
		||||
      std::unique_ptr<Literal> relaid =
 | 
			
		||||
          argument.Relayout(shape_with_layout.value());
 | 
			
		||||
      return ToBuffer(client, /*device_ordinal=*/0, *relaid);
 | 
			
		||||
      Literal relaid = argument.Relayout(shape_with_layout.value());
 | 
			
		||||
      return ToBuffer(client, /*device_ordinal=*/0, relaid);
 | 
			
		||||
    }
 | 
			
		||||
    return ToBuffer(client, /*device_ordinal=*/0, argument);
 | 
			
		||||
  }();
 | 
			
		||||
@ -151,7 +150,7 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
 | 
			
		||||
  return new LocalShapedBuffer(std::move(buf).ValueOrDie());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> LocalShapedBuffer::ToLiteral() const {
 | 
			
		||||
StatusOr<Literal> LocalShapedBuffer::ToLiteral() const {
 | 
			
		||||
  LocalClient* client = GetOrCreateLocalClient();
 | 
			
		||||
  return client->ShapedBufferToLiteral(*shaped_buffer());
 | 
			
		||||
}
 | 
			
		||||
@ -160,7 +159,7 @@ CompiledLocalComputation::CompiledLocalComputation(
 | 
			
		||||
    std::unique_ptr<LocalExecutable> executable)
 | 
			
		||||
    : executable_(std::move(executable)) {}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
 | 
			
		||||
StatusOr<Literal> CompiledLocalComputation::Execute(
 | 
			
		||||
    const std::vector<Literal>& arguments,
 | 
			
		||||
    const std::vector<absl::optional<Shape>>& shapes_with_layout) {
 | 
			
		||||
  LocalClient* client = GetOrCreateLocalClient();
 | 
			
		||||
@ -169,7 +168,7 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
 | 
			
		||||
 | 
			
		||||
  // Each replica populates a StatusOr result, but only replica zero actually
 | 
			
		||||
  // retrieves its literal value.
 | 
			
		||||
  std::vector<StatusOr<std::unique_ptr<Literal>>> results(GetReplicaCount());
 | 
			
		||||
  std::vector<StatusOr<Literal>> results(GetReplicaCount());
 | 
			
		||||
  {
 | 
			
		||||
    tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
 | 
			
		||||
                                        GetReplicaCount());
 | 
			
		||||
@ -198,9 +197,8 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
 | 
			
		||||
 | 
			
		||||
              StatusOr<ScopedShapedBuffer> pushed;
 | 
			
		||||
              if (shape_with_layout) {
 | 
			
		||||
                std::unique_ptr<Literal> relaid =
 | 
			
		||||
                    argument.Relayout(shape_with_layout.value());
 | 
			
		||||
                pushed = ToBuffer(client, device_ordinal, *relaid);
 | 
			
		||||
                Literal relaid = argument.Relayout(shape_with_layout.value());
 | 
			
		||||
                pushed = ToBuffer(client, device_ordinal, relaid);
 | 
			
		||||
              } else {
 | 
			
		||||
                pushed = ToBuffer(client, device_ordinal, argument);
 | 
			
		||||
              }
 | 
			
		||||
 | 
			
		||||
@ -51,8 +51,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number);
 | 
			
		||||
// Transfers a literal of the given shape from the outfeed of the given replica.
 | 
			
		||||
//
 | 
			
		||||
// The replica number is resolved to an appropriate device ordinal.
 | 
			
		||||
StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
 | 
			
		||||
    const Shape& shape, int replica_number);
 | 
			
		||||
StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
 | 
			
		||||
                                                  int replica_number);
 | 
			
		||||
 | 
			
		||||
// Wraps a ScopedShapedBuffer produced by copying a literal "to
 | 
			
		||||
// device," i.e. copying a literal to a scoped buffer via the local
 | 
			
		||||
@ -65,7 +65,7 @@ class LocalShapedBuffer {
 | 
			
		||||
  LocalShapedBuffer(ScopedShapedBuffer shaped_buffer);
 | 
			
		||||
  const ScopedShapedBuffer* shaped_buffer() const;
 | 
			
		||||
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal> > ToLiteral() const;
 | 
			
		||||
  StatusOr<Literal> ToLiteral() const;
 | 
			
		||||
 | 
			
		||||
  // Transfers ownership of the encapsulated ShapedBuffer to the caller,
 | 
			
		||||
  // analogous to std::unique_ptr::release().
 | 
			
		||||
@ -117,7 +117,7 @@ class CompiledLocalComputation {
 | 
			
		||||
  // with optionally-specified argument layouts. The literals will be
 | 
			
		||||
  // re-laid out according to the corresponding elements of
 | 
			
		||||
  // shapes_with_layout.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal> > Execute(
 | 
			
		||||
  StatusOr<Literal> Execute(
 | 
			
		||||
      const std::vector<Literal>& arguments,
 | 
			
		||||
      const std::vector<absl::optional<Shape> >& shapes_with_layout);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -216,9 +216,9 @@ tensorflow::ImportNumpy();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
%typemap(out) StatusOr< std::unique_ptr<Literal> > {
 | 
			
		||||
%typemap(out) StatusOr<Literal> {
 | 
			
		||||
  if ($1.ok()) {
 | 
			
		||||
    std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
 | 
			
		||||
    Literal value = $1.ConsumeValueOrDie();
 | 
			
		||||
    $result = numpy::PyObjectFromXlaLiteral(*value);
 | 
			
		||||
  } else {
 | 
			
		||||
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
 | 
			
		||||
@ -346,25 +346,25 @@ tensorflow::ImportNumpy();
 | 
			
		||||
 | 
			
		||||
// Literal
 | 
			
		||||
 | 
			
		||||
%typemap(in) const Literal& (StatusOr< std::unique_ptr<Literal> > literal_status) {
 | 
			
		||||
%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
 | 
			
		||||
  literal_status = numpy::XlaLiteralFromPyObject($input);
 | 
			
		||||
  if (!literal_status.ok()) {
 | 
			
		||||
    PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
 | 
			
		||||
    SWIG_fail;
 | 
			
		||||
  }
 | 
			
		||||
  $1 = literal_status.ValueOrDie().get();
 | 
			
		||||
  $1 = &literal_status.ValueOrDie();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
%typemap(out) std::unique_ptr<Literal> {
 | 
			
		||||
%typemap(out) Literal {
 | 
			
		||||
  $result = numpy::PyObjectFromXlaLiteral(*$1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
%typemap(out) StatusOr< std::unique_ptr<Literal> > {
 | 
			
		||||
%typemap(out) StatusOr<Literal> {
 | 
			
		||||
  if (!$1.ok()) {
 | 
			
		||||
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
 | 
			
		||||
    SWIG_fail;
 | 
			
		||||
  }
 | 
			
		||||
  $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
 | 
			
		||||
  $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
 | 
			
		||||
@ -375,13 +375,13 @@ tensorflow::ImportNumpy();
 | 
			
		||||
  const int size = PySequence_Size($input);
 | 
			
		||||
  for (int i = 0; i < size; ++i) {
 | 
			
		||||
    PyObject* o = PySequence_GetItem($input, i);
 | 
			
		||||
    StatusOr< std::unique_ptr<Literal> > literal_status = numpy::XlaLiteralFromPyObject(o);
 | 
			
		||||
    StatusOr<Literal> literal_status = numpy::XlaLiteralFromPyObject(o);
 | 
			
		||||
    if (!literal_status.ok()) {
 | 
			
		||||
      PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
 | 
			
		||||
      Py_DECREF(o);
 | 
			
		||||
      SWIG_fail;
 | 
			
		||||
    }
 | 
			
		||||
    temps.push_back(std::move(*literal_status.ConsumeValueOrDie()));
 | 
			
		||||
    temps.push_back(literal_status.ConsumeValueOrDie());
 | 
			
		||||
    Py_DECREF(o);
 | 
			
		||||
  }
 | 
			
		||||
  $1 = &temps;
 | 
			
		||||
 | 
			
		||||
@ -368,10 +368,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
 | 
			
		||||
StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o) {
 | 
			
		||||
  if (PyTuple_Check(o)) {
 | 
			
		||||
    int num_elements = PyTuple_Size(o);
 | 
			
		||||
    std::vector<std::unique_ptr<Literal>> elements;
 | 
			
		||||
    std::vector<Literal> elements;
 | 
			
		||||
    elements.reserve(num_elements);
 | 
			
		||||
    for (int i = 0; i < num_elements; i++) {
 | 
			
		||||
      PyObject* element = PyTuple_GetItem(o, i);
 | 
			
		||||
@ -389,8 +389,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
 | 
			
		||||
    int np_type = PyArray_TYPE(py_array);
 | 
			
		||||
    auto literal = LiteralUtil::CreateFromDimensions(
 | 
			
		||||
        NumpyTypeToPrimitiveType(np_type), dimensions);
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        CopyNumpyArrayToLiteral(np_type, py_array, literal.get()));
 | 
			
		||||
    TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal));
 | 
			
		||||
    return std::move(literal);
 | 
			
		||||
  } else {
 | 
			
		||||
    return InvalidArgument(
 | 
			
		||||
 | 
			
		||||
@ -82,7 +82,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
 | 
			
		||||
// To avoid transferring ownership of the data buffers that underlie
 | 
			
		||||
// PyArrays and XLA literals, this function makes deep copies of all
 | 
			
		||||
// array data.
 | 
			
		||||
StatusOr<std::unique_ptr<Literal> > XlaLiteralFromPyObject(PyObject* o);
 | 
			
		||||
StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
 | 
			
		||||
 | 
			
		||||
// The following functions copy array data from the buffers underlying Numpy
 | 
			
		||||
// ndarrays into those underlying XLA literals, and vice versa.
 | 
			
		||||
 | 
			
		||||
@ -529,13 +529,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ordered_input_dimensions[0] =
 | 
			
		||||
      lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0));
 | 
			
		||||
      lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
 | 
			
		||||
  ordered_input_dimensions[1] =
 | 
			
		||||
      lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1));
 | 
			
		||||
      lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
 | 
			
		||||
  ordered_kernel_dimensions[0] =
 | 
			
		||||
      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
 | 
			
		||||
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
 | 
			
		||||
  ordered_kernel_dimensions[1] =
 | 
			
		||||
      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
 | 
			
		||||
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
 | 
			
		||||
 | 
			
		||||
  std::vector<std::pair<int64, int64>> paddings =
 | 
			
		||||
      MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
 | 
			
		||||
@ -546,7 +546,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
 | 
			
		||||
 | 
			
		||||
  WindowDimension dim;
 | 
			
		||||
  dim.set_size(
 | 
			
		||||
      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
 | 
			
		||||
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
 | 
			
		||||
  dim.set_stride(kernel_stride.first);
 | 
			
		||||
  dim.set_padding_low(paddings[0].first);
 | 
			
		||||
  dim.set_padding_high(paddings[0].second);
 | 
			
		||||
@ -556,7 +556,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
 | 
			
		||||
 | 
			
		||||
  WindowDimension dim2;
 | 
			
		||||
  dim2.set_size(
 | 
			
		||||
      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
 | 
			
		||||
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
 | 
			
		||||
  dim2.set_stride(kernel_stride.second);
 | 
			
		||||
  dim2.set_padding_low(paddings[1].first);
 | 
			
		||||
  dim2.set_padding_high(paddings[1].second);
 | 
			
		||||
@ -565,7 +565,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
 | 
			
		||||
  *window.add_dimensions() = dim2;
 | 
			
		||||
 | 
			
		||||
  const Shape& shape = ShapeInference::InferConvolveShape(
 | 
			
		||||
                           lhs_literal->shape(), rhs_literal->shape(),
 | 
			
		||||
                           lhs_literal.shape(), rhs_literal.shape(),
 | 
			
		||||
                           /*feature_group_count=*/1, window, dnums)
 | 
			
		||||
                           .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
@ -585,18 +585,18 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
 | 
			
		||||
  auto computation = module.AddEntryComputation(b.Build());
 | 
			
		||||
 | 
			
		||||
  HloEvaluator evaluator;
 | 
			
		||||
  std::unique_ptr<Literal> result_literal =
 | 
			
		||||
  Literal result_literal =
 | 
			
		||||
      evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
 | 
			
		||||
  CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4);
 | 
			
		||||
  auto result =
 | 
			
		||||
      absl::make_unique<Array4D<float>>(result_literal->shape().dimensions(0),
 | 
			
		||||
                                        result_literal->shape().dimensions(1),
 | 
			
		||||
                                        result_literal->shape().dimensions(2),
 | 
			
		||||
                                        result_literal->shape().dimensions(3));
 | 
			
		||||
      absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
 | 
			
		||||
                                        result_literal.shape().dimensions(1),
 | 
			
		||||
                                        result_literal.shape().dimensions(2),
 | 
			
		||||
                                        result_literal.shape().dimensions(3));
 | 
			
		||||
 | 
			
		||||
  result->Each([&](absl::Span<const int64> indices, float* value) {
 | 
			
		||||
    *value = result_literal->Get<float>(indices);
 | 
			
		||||
    *value = result_literal.Get<float>(indices);
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  return result;
 | 
			
		||||
 | 
			
		||||
@ -55,7 +55,7 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) {
 | 
			
		||||
  auto result = ReferenceUtil::TransposeArray2D(*matrix_);
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
 | 
			
		||||
                                       *actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
                                       actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ReferenceUtilTest, MatmulArray2D) {
 | 
			
		||||
@ -67,14 +67,14 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
 | 
			
		||||
  auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
 | 
			
		||||
                                       *actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
                                       actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
 | 
			
		||||
  auto add = [](float lhs, float rhs) { return lhs + rhs; };
 | 
			
		||||
  auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR1<float>(*result);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
 | 
			
		||||
  LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, actual_literal,
 | 
			
		||||
                                       ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -82,7 +82,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
 | 
			
		||||
  auto add = [](float lhs, float rhs) { return lhs + rhs; };
 | 
			
		||||
  auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR1<float>(*result);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
 | 
			
		||||
  LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, actual_literal,
 | 
			
		||||
                                       ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -90,14 +90,14 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
 | 
			
		||||
  auto result = LiteralUtil::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
 | 
			
		||||
      Array4D<float>(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2},
 | 
			
		||||
      [](float a, float b) { return a + b; }));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({0}, *result);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({0}, result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ReferenceUtilTest, MapArray2D) {
 | 
			
		||||
  auto identity = [](float value) { return log(exp(value)); };
 | 
			
		||||
  auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
 | 
			
		||||
  LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
 | 
			
		||||
  LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal,
 | 
			
		||||
                                       ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -108,7 +108,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
 | 
			
		||||
  auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
 | 
			
		||||
                                       *actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
                                       actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ReferenceUtilTest, MapArray4D) {
 | 
			
		||||
@ -121,7 +121,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
 | 
			
		||||
 | 
			
		||||
  Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
 | 
			
		||||
  expected.FillWithMultiples(2.0f);
 | 
			
		||||
  LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
 | 
			
		||||
  LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
 | 
			
		||||
                                       ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -138,7 +138,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
 | 
			
		||||
 | 
			
		||||
  Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
 | 
			
		||||
  expected.Fill(0.0f);
 | 
			
		||||
  LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
 | 
			
		||||
  LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
 | 
			
		||||
                                       ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -146,16 +146,16 @@ TEST_F(ReferenceUtilTest, SliceArray2D) {
 | 
			
		||||
  auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}});
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}},
 | 
			
		||||
                                       *actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}}, actual_literal,
 | 
			
		||||
                                       ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ReferenceUtilTest, SliceStridedArray2D) {
 | 
			
		||||
  auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}});
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}},
 | 
			
		||||
                                       *actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}}, actual_literal,
 | 
			
		||||
                                       ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ReferenceUtilTest, SliceArray3D) {
 | 
			
		||||
@ -167,7 +167,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) {
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR3Near<float>(
 | 
			
		||||
      {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal,
 | 
			
		||||
      {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal,
 | 
			
		||||
      ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -180,8 +180,8 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) {
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR3Near<float>(
 | 
			
		||||
      {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}},
 | 
			
		||||
      *actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
      {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal,
 | 
			
		||||
      ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ReferenceUtilTest, SliceArray4D) {
 | 
			
		||||
@ -194,7 +194,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) {
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR4Near<float>(
 | 
			
		||||
      {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}},
 | 
			
		||||
      *actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
      actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
 | 
			
		||||
@ -208,7 +208,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
 | 
			
		||||
  LiteralTestUtil::ExpectR4Near<float>(
 | 
			
		||||
      {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}},
 | 
			
		||||
        {{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}},
 | 
			
		||||
      *actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
      actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
 | 
			
		||||
@ -220,7 +220,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
 | 
			
		||||
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
 | 
			
		||||
  LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
 | 
			
		||||
                                              ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -233,7 +233,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) {
 | 
			
		||||
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
 | 
			
		||||
  LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
 | 
			
		||||
                                              ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -268,7 +268,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
 | 
			
		||||
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
 | 
			
		||||
  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
 | 
			
		||||
                                              ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -302,7 +302,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) {
 | 
			
		||||
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
 | 
			
		||||
  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
 | 
			
		||||
                                              ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -358,7 +358,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
 | 
			
		||||
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
 | 
			
		||||
  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
 | 
			
		||||
                                              ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -411,7 +411,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
 | 
			
		||||
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
 | 
			
		||||
  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
 | 
			
		||||
                                              ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -424,7 +424,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) {
 | 
			
		||||
      [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
 | 
			
		||||
  auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
 | 
			
		||||
                                *actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
                                actual_literal, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -95,12 +95,11 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
 | 
			
		||||
  std::vector<float> expected = {
 | 
			
		||||
      1.85840735, -1.85840735, 2.28318531,   -2.28318531,  -6.42477796,
 | 
			
		||||
      6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
      LiteralUtil::CreateR1<float>(expected);
 | 
			
		||||
  Literal expected_literal = LiteralUtil::CreateR1<float>(expected);
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
 | 
			
		||||
                                                   computation, {}, nullptr));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal,
 | 
			
		||||
                                    ErrorSpec(0.0001)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -205,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
  HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
 | 
			
		||||
    HloInstruction* zero =
 | 
			
		||||
        computation_->AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
            LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
 | 
			
		||||
            LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
 | 
			
		||||
    HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
 | 
			
		||||
    Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
 | 
			
		||||
    return computation_->AddInstruction(HloInstruction::CreateReduce(
 | 
			
		||||
@ -527,7 +527,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation,
 | 
			
		||||
    return computation->AddInstruction(HloInstruction::CreateTuple(elems));
 | 
			
		||||
  } else {
 | 
			
		||||
    return computation->AddInstruction(
 | 
			
		||||
        HloInstruction::CreateConstant(literal.CloneToUnique()));
 | 
			
		||||
        HloInstruction::CreateConstant(literal.Clone()));
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -546,7 +546,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
 | 
			
		||||
  // If a literal is all the same element replace it with a scalar broadcast.
 | 
			
		||||
  if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
 | 
			
		||||
      constant->literal().IsAllFirst()) {
 | 
			
		||||
    std::unique_ptr<Literal> unique_scalar = absl::make_unique<Literal>(
 | 
			
		||||
    Literal unique_scalar(
 | 
			
		||||
        LiteralUtil::GetFirstScalarLiteral(constant->literal()));
 | 
			
		||||
    HloInstruction* scalar = computation_->AddInstruction(
 | 
			
		||||
        HloInstruction::CreateConstant(std::move(unique_scalar)));
 | 
			
		||||
@ -676,7 +676,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
 | 
			
		||||
        return Status::OK();
 | 
			
		||||
    }
 | 
			
		||||
    auto inverse = computation_->AddInstruction(
 | 
			
		||||
        HloInstruction::CreateConstant((new_literal.CloneToUnique())));
 | 
			
		||||
        HloInstruction::CreateConstant((new_literal.Clone())));
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(auto new_divide,
 | 
			
		||||
                        MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
 | 
			
		||||
    return ReplaceInstruction(divide, new_divide);
 | 
			
		||||
@ -1469,7 +1469,7 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
 | 
			
		||||
  auto* iota = Cast<HloIotaInstruction>(instruction);
 | 
			
		||||
  if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
 | 
			
		||||
    auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
        LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique()));
 | 
			
		||||
        LiteralUtil::Zero(iota->shape().element_type()).Clone()));
 | 
			
		||||
    return ReplaceWithNewInstruction(
 | 
			
		||||
        iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
 | 
			
		||||
  }
 | 
			
		||||
@ -1572,7 +1572,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
 | 
			
		||||
  CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
 | 
			
		||||
  if (IsAll(rhs, 0)) {
 | 
			
		||||
    auto one = HloInstruction::CreateConstant(
 | 
			
		||||
        LiteralUtil::One(power->shape().element_type()).CloneToUnique());
 | 
			
		||||
        LiteralUtil::One(power->shape().element_type()).Clone());
 | 
			
		||||
    std::unique_ptr<HloInstruction> ones;
 | 
			
		||||
    if (ShapeUtil::IsScalar(power->shape())) {
 | 
			
		||||
      ones = std::move(one);
 | 
			
		||||
@ -1607,7 +1607,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
 | 
			
		||||
  VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
 | 
			
		||||
  if (IsAll(rhs, -1)) {
 | 
			
		||||
    auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
        LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
 | 
			
		||||
        LiteralUtil::One(rhs->shape().element_type()).Clone()));
 | 
			
		||||
 | 
			
		||||
    // Explicitly broadcast scalar 1 to the output shape, to avoid implicit
 | 
			
		||||
    // broadcast in divide HLO as we are trying to eliminate implicit
 | 
			
		||||
@ -2062,7 +2062,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
 | 
			
		||||
      if (!converted_pad_literal.ok()) {
 | 
			
		||||
        return false;
 | 
			
		||||
      }
 | 
			
		||||
      return *converted_pad_literal.ValueOrDie() == reduce_init_literal;
 | 
			
		||||
      return converted_pad_literal.ValueOrDie() == reduce_init_literal;
 | 
			
		||||
    };
 | 
			
		||||
    // The pad value is usually a constant, so we handle that case and do not
 | 
			
		||||
    // try to get more fancy about proving equivalence in cases beyond that.
 | 
			
		||||
@ -2223,8 +2223,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
 | 
			
		||||
        HloInstruction::CreateBroadcast(
 | 
			
		||||
            convolution->shape(),
 | 
			
		||||
            computation_->AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
                LiteralUtil::Zero(convolution->shape().element_type())
 | 
			
		||||
                    .CloneToUnique())),
 | 
			
		||||
                LiteralUtil::Zero(convolution->shape().element_type()))),
 | 
			
		||||
            {}));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2932,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
 | 
			
		||||
  HloComputation::Builder builder(TestName());
 | 
			
		||||
  const float constant_scalar = 7.3f;
 | 
			
		||||
  std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
 | 
			
		||||
  std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
 | 
			
		||||
      {LiteralUtil::CreateR0<float>(constant_scalar).get(),
 | 
			
		||||
       LiteralUtil::CreateR1<float>(constant_vector).get()});
 | 
			
		||||
  Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
 | 
			
		||||
                        LiteralUtil::CreateR1<float>(constant_vector)};
 | 
			
		||||
  Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
 | 
			
		||||
  builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
 | 
			
		||||
 | 
			
		||||
  auto computation = module().AddEntryComputation(builder.Build());
 | 
			
		||||
 | 
			
		||||
@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
 | 
			
		||||
  const Shape feature_shape = scale->shape();
 | 
			
		||||
 | 
			
		||||
  auto zero_literal = LiteralUtil::CreateR0(0.0f);
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
 | 
			
		||||
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
 | 
			
		||||
 | 
			
		||||
  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
 | 
			
		||||
  auto epsilon = add(HloInstruction::CreateBroadcast(
 | 
			
		||||
      operand_shape,
 | 
			
		||||
      add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
 | 
			
		||||
@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
 | 
			
		||||
  const Shape feature_shape = scale->shape();
 | 
			
		||||
 | 
			
		||||
  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
 | 
			
		||||
  auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
 | 
			
		||||
      operand_shape,
 | 
			
		||||
      computation_->AddInstruction(
 | 
			
		||||
@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
 | 
			
		||||
  const int64 elements_per_feature_int64 = size_in_elements / feature_count;
 | 
			
		||||
 | 
			
		||||
  auto zero_literal = LiteralUtil::CreateR0(0.0f);
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
 | 
			
		||||
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
 | 
			
		||||
 | 
			
		||||
  auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
 | 
			
		||||
  auto epsilon_scalar =
 | 
			
		||||
      add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
 | 
			
		||||
  auto epsilon_activation = add(
 | 
			
		||||
@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
 | 
			
		||||
  auto elements_per_feature_literal =
 | 
			
		||||
      LiteralUtil::CreateR0<float>(elements_per_feature_int64);
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
 | 
			
		||||
                      elements_per_feature_literal->Convert(ptype));
 | 
			
		||||
                      elements_per_feature_literal.Convert(ptype));
 | 
			
		||||
  auto elements_per_feature = add(
 | 
			
		||||
      HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
 | 
			
		||||
  auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
 | 
			
		||||
 | 
			
		||||
@ -163,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
 | 
			
		||||
  EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
 | 
			
		||||
  EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(
 | 
			
		||||
      *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)),
 | 
			
		||||
      LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
 | 
			
		||||
      dot->operand(0)->literal()));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(
 | 
			
		||||
      *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)),
 | 
			
		||||
      LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
 | 
			
		||||
      dot->operand(1)->literal()));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1245,9 +1245,10 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
 | 
			
		||||
  // Test that a tuple constant which is forwarded to the computation output
 | 
			
		||||
  // is properly handled.
 | 
			
		||||
  auto builder = HloComputation::Builder(TestName());
 | 
			
		||||
  Literal elements[] = {LiteralUtil::CreateR0<int64>(0),
 | 
			
		||||
                        LiteralUtil::CreateR0<int64>(1)};
 | 
			
		||||
  builder.AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
      LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
 | 
			
		||||
                              LiteralUtil::CreateR0<int64>(1).get()})));
 | 
			
		||||
      LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
 | 
			
		||||
 | 
			
		||||
  auto module = CreateNewModule();
 | 
			
		||||
  module->AddEntryComputation(builder.Build());
 | 
			
		||||
 | 
			
		||||
@ -440,15 +440,15 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
 | 
			
		||||
  // computation. The buffer containing {0, 1} is copied by GetTupleElement, and
 | 
			
		||||
  // the buffers containing {3} and 3 are dead.
 | 
			
		||||
  auto builder = HloComputation::Builder(TestName());
 | 
			
		||||
  auto inner_tuple0 =
 | 
			
		||||
      LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
 | 
			
		||||
                              LiteralUtil::CreateR0<int64>(1).get()});
 | 
			
		||||
  auto inner_tuple1 =
 | 
			
		||||
      LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
 | 
			
		||||
  Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
 | 
			
		||||
                         LiteralUtil::CreateR0<int64>(1)};
 | 
			
		||||
  auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
 | 
			
		||||
  Literal element1 = LiteralUtil::CreateR0<int64>(3);
 | 
			
		||||
  auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
 | 
			
		||||
  auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
      LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
 | 
			
		||||
      LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
 | 
			
		||||
  builder.AddInstruction(HloInstruction::CreateGetTupleElement(
 | 
			
		||||
      inner_tuple0->shape(), tuple_constant, 0));
 | 
			
		||||
      inner_tuple0.shape(), tuple_constant, 0));
 | 
			
		||||
 | 
			
		||||
  auto module = CreateNewModule();
 | 
			
		||||
  module->AddEntryComputation(builder.Build());
 | 
			
		||||
 | 
			
		||||
@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
 | 
			
		||||
    expanded_filter = add(HloInstruction::CreateConcatenate(
 | 
			
		||||
        expanded_filter_shape, concat_operands, input_feature_dim));
 | 
			
		||||
  }
 | 
			
		||||
  auto zero = add(HloInstruction::CreateConstant(absl::make_unique<Literal>(
 | 
			
		||||
      LiteralUtil::Zero(expanded_filter_shape.element_type()))));
 | 
			
		||||
  auto zero = add(HloInstruction::CreateConstant(
 | 
			
		||||
      LiteralUtil::Zero(expanded_filter_shape.element_type())));
 | 
			
		||||
  auto zero_filter =
 | 
			
		||||
      add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
 | 
			
		||||
  auto new_filter = add(
 | 
			
		||||
 | 
			
		||||
@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
 | 
			
		||||
  auto builder = HloComputation::Builder(TestName());
 | 
			
		||||
  auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
 | 
			
		||||
  auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
 | 
			
		||||
  Shape vshape = input_literal1->shape();
 | 
			
		||||
  Shape vshape = input_literal1.shape();
 | 
			
		||||
 | 
			
		||||
  auto input1 = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(std::move(input_literal1)));
 | 
			
		||||
@ -78,13 +78,13 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
 | 
			
		||||
  auto result = ExecuteAndTransfer(module->Clone(), {});
 | 
			
		||||
 | 
			
		||||
  // Check the output correctness.
 | 
			
		||||
  LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, result, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
 | 
			
		||||
  auto builder = HloComputation::Builder(TestName());
 | 
			
		||||
  auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
 | 
			
		||||
  Shape vshape = input_literal->shape();
 | 
			
		||||
  Shape vshape = input_literal.shape();
 | 
			
		||||
 | 
			
		||||
  auto input = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(std::move(input_literal)));
 | 
			
		||||
@ -125,8 +125,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
 | 
			
		||||
  auto result = ExecuteAndTransfer(module->Clone(), {});
 | 
			
		||||
 | 
			
		||||
  // Check the output correctness.
 | 
			
		||||
  LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
 | 
			
		||||
                                       error_spec_);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, result, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
 | 
			
		||||
@ -135,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
 | 
			
		||||
  auto module = CreateNewModule();
 | 
			
		||||
  auto builder = HloComputation::Builder(TestName());
 | 
			
		||||
  auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
 | 
			
		||||
  Shape vshape = input_literal->shape();
 | 
			
		||||
  Shape vshape = input_literal.shape();
 | 
			
		||||
 | 
			
		||||
  auto input = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(std::move(input_literal)));
 | 
			
		||||
@ -213,7 +212,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
 | 
			
		||||
 | 
			
		||||
  // Check the output correctness.
 | 
			
		||||
  LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
 | 
			
		||||
                                       *result, error_spec_);
 | 
			
		||||
                                       result, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
 | 
			
		||||
@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
 | 
			
		||||
  // each fusion instruction to ensure that negate is not duplicated.
 | 
			
		||||
  auto builder = HloComputation::Builder(TestName());
 | 
			
		||||
  auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
 | 
			
		||||
  Shape vshape = input_literal->shape();
 | 
			
		||||
  Shape vshape = input_literal.shape();
 | 
			
		||||
 | 
			
		||||
  auto constant = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(std::move(input_literal)));
 | 
			
		||||
 | 
			
		||||
@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedR0Bool) {
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedR1U32) {
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedR2F32) {
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedR3F32) {
 | 
			
		||||
  TestInfeedRoundTrip(
 | 
			
		||||
      *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
 | 
			
		||||
                              {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
 | 
			
		||||
      LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
 | 
			
		||||
                             {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
 | 
			
		||||
  const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
 | 
			
		||||
  const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
 | 
			
		||||
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
 | 
			
		||||
      {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
 | 
			
		||||
       {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
 | 
			
		||||
      r3_dim0minor));
 | 
			
		||||
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
 | 
			
		||||
      {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
 | 
			
		||||
       {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
 | 
			
		||||
      r3_dim0major));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedR4S32) {
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR4(
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR4(
 | 
			
		||||
      {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
 | 
			
		||||
       {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedTuple) {
 | 
			
		||||
  TestInfeedRoundTrip(
 | 
			
		||||
      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
 | 
			
		||||
                               LiteralUtil::CreateR0<bool>(false).get()}));
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
 | 
			
		||||
       LiteralUtil::CreateR0<bool>(false)}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests Infeed operation used in a while loop, as in the code below. The
 | 
			
		||||
@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
 | 
			
		||||
 | 
			
		||||
  // Send 5 Infeed data of shape F32[3].
 | 
			
		||||
  ASSERT_IS_OK(
 | 
			
		||||
      client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
 | 
			
		||||
      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({1, 2, 3})));
 | 
			
		||||
  ASSERT_IS_OK(
 | 
			
		||||
      client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
 | 
			
		||||
      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({4, 5, 6})));
 | 
			
		||||
  ASSERT_IS_OK(
 | 
			
		||||
      client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
 | 
			
		||||
      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({7, 8, 9})));
 | 
			
		||||
  ASSERT_IS_OK(
 | 
			
		||||
      client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
 | 
			
		||||
      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({10, 11, 12})));
 | 
			
		||||
  ASSERT_IS_OK(
 | 
			
		||||
      client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
 | 
			
		||||
      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({13, 14, 15})));
 | 
			
		||||
 | 
			
		||||
  delete computation_thread;  // Joins the thread.
 | 
			
		||||
  auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  // Only the first 3 infeed data should be added.
 | 
			
		||||
  LiteralTestUtil::ExpectR0Near<float>(45.0f, *result_literal, ErrorSpec{1e-7});
 | 
			
		||||
  LiteralTestUtil::ExpectR0Near<float>(45.0f, result_literal, ErrorSpec{1e-7});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests two Infeed operations with a total order. The order is enforced by
 | 
			
		||||
@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
 | 
			
		||||
 | 
			
		||||
  // Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
 | 
			
		||||
  ASSERT_IS_OK(client_->TransferToInfeed(
 | 
			
		||||
      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
 | 
			
		||||
                               LiteralUtil::CreateR0<bool>(true).get()})));
 | 
			
		||||
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
 | 
			
		||||
                                        LiteralUtil::CreateR0<bool>(true)})));
 | 
			
		||||
  ASSERT_IS_OK(client_->TransferToInfeed(
 | 
			
		||||
      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
 | 
			
		||||
                               LiteralUtil::CreateR0<bool>(true).get()})));
 | 
			
		||||
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({3, 4}),
 | 
			
		||||
                                        LiteralUtil::CreateR0<bool>(true)})));
 | 
			
		||||
  ASSERT_IS_OK(client_->TransferToInfeed(
 | 
			
		||||
      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
 | 
			
		||||
                               LiteralUtil::CreateR0<bool>(true).get()})));
 | 
			
		||||
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({5, 6}),
 | 
			
		||||
                                        LiteralUtil::CreateR0<bool>(true)})));
 | 
			
		||||
  ASSERT_IS_OK(client_->TransferToInfeed(
 | 
			
		||||
      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
 | 
			
		||||
                               LiteralUtil::CreateR0<bool>(false).get()})));
 | 
			
		||||
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8}),
 | 
			
		||||
                                        LiteralUtil::CreateR0<bool>(false)})));
 | 
			
		||||
 | 
			
		||||
  // Asynchronously launch the execution on the device.
 | 
			
		||||
  std::unique_ptr<GlobalData> result;
 | 
			
		||||
@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
 | 
			
		||||
  // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
 | 
			
		||||
  sleep(1);
 | 
			
		||||
  ASSERT_IS_OK(client_->TransferToInfeed(
 | 
			
		||||
      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
 | 
			
		||||
                               LiteralUtil::CreateR0<bool>(true).get()})));
 | 
			
		||||
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2, 3}),
 | 
			
		||||
                                        LiteralUtil::CreateR0<bool>(true)})));
 | 
			
		||||
  ASSERT_IS_OK(client_->TransferToInfeed(
 | 
			
		||||
      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
 | 
			
		||||
                               LiteralUtil::CreateR0<bool>(false).get()})));
 | 
			
		||||
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8, 9}),
 | 
			
		||||
                                        LiteralUtil::CreateR0<bool>(false)})));
 | 
			
		||||
  ASSERT_IS_OK(client_->TransferToInfeed(
 | 
			
		||||
      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
 | 
			
		||||
                               LiteralUtil::CreateR0<bool>(true).get()})));
 | 
			
		||||
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({4, 5, 6}),
 | 
			
		||||
                                        LiteralUtil::CreateR0<bool>(true)})));
 | 
			
		||||
 | 
			
		||||
  // Wait for the execution to be done, and transfer the result.
 | 
			
		||||
  delete computation_thread;  // Joins the thread.
 | 
			
		||||
  auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  // Only the first 6 infeed data should be added.
 | 
			
		||||
  LiteralTestUtil::ExpectR0Near<float>(66.0f, *result_literal, ErrorSpec{1e-7});
 | 
			
		||||
  LiteralTestUtil::ExpectR0Near<float>(66.0f, result_literal, ErrorSpec{1e-7});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {};
 | 
			
		||||
TEST_F(CpuNoAliasTest, Concat) {
 | 
			
		||||
  HloComputation::Builder builder(TestName());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal =
 | 
			
		||||
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
 | 
			
		||||
  Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
 | 
			
		||||
  auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
 | 
			
		||||
  HloInstruction* param_x = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateParameter(0, param_shape, "x"));
 | 
			
		||||
 | 
			
		||||
@ -56,9 +56,9 @@ ENTRY main {
 | 
			
		||||
}
 | 
			
		||||
)";
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
 | 
			
		||||
  std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
 | 
			
		||||
  RunTest(hlo_text, {lhs.get(), rhs.get()});
 | 
			
		||||
  Literal lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
 | 
			
		||||
  Literal rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
 | 
			
		||||
  RunTest(hlo_text, {&lhs, &rhs});
 | 
			
		||||
}
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
 | 
			
		||||
@ -125,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
 | 
			
		||||
                       device_memory.size());
 | 
			
		||||
          // Element is array-shaped: transfer array data to device buffer.
 | 
			
		||||
          const auto subliteral = LiteralSlice(literal, index);
 | 
			
		||||
          std::unique_ptr<Literal> relayed_out_literal;
 | 
			
		||||
          Literal relayed_out_literal;
 | 
			
		||||
          const void* source;
 | 
			
		||||
          if (LayoutUtil::Equal(device_subshape.layout(),
 | 
			
		||||
                                subliteral.shape().layout())) {
 | 
			
		||||
@ -138,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
 | 
			
		||||
            // Relayout data before transferring.
 | 
			
		||||
            relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
 | 
			
		||||
                                                      /*shape_index=*/{});
 | 
			
		||||
            source = relayed_out_literal->untyped_data();
 | 
			
		||||
            source = relayed_out_literal.untyped_data();
 | 
			
		||||
            TF_RETURN_IF_ERROR(TransferBufferToDevice(
 | 
			
		||||
                stream,
 | 
			
		||||
                /*size=*/GetByteSizeRequirement(device_subshape), source,
 | 
			
		||||
 | 
			
		||||
@ -590,7 +590,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) {
 | 
			
		||||
  Array4D<float> constant_arr(4, 4, 2, 2);
 | 
			
		||||
  constant_arr.FillIota(0);
 | 
			
		||||
  string constant_str =
 | 
			
		||||
      LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString();
 | 
			
		||||
      LiteralUtil::CreateR4FromArray4D(constant_arr).ToString();
 | 
			
		||||
  ParseAndVerifyModule(absl::StrFormat(R"(
 | 
			
		||||
    HloModule test
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -23,7 +23,6 @@ limitations under the License.
 | 
			
		||||
namespace xla {
 | 
			
		||||
namespace gpu {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// We want the input/output feature counts of an f16 conv to be factors of 8,
 | 
			
		||||
// because without this cudnn can't use tensor cores on the conv.
 | 
			
		||||
static constexpr int64 kDesiredNumFeaturesFactor = 8;
 | 
			
		||||
@ -63,8 +62,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
 | 
			
		||||
  HloComputation* comp = instr->parent();
 | 
			
		||||
 | 
			
		||||
  const Shape& shape = instr->shape();
 | 
			
		||||
  auto* zero = comp->AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
      LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
 | 
			
		||||
  auto* zero = comp->AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
 | 
			
		||||
 | 
			
		||||
  PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -68,9 +68,8 @@ HloInstruction* MaybePaddedAndSlicedInput(
 | 
			
		||||
          conv_window.dimensions(i).base_dilation() - 1);
 | 
			
		||||
    }
 | 
			
		||||
    PrimitiveType element_type = input->shape().element_type();
 | 
			
		||||
    HloInstruction* padding =
 | 
			
		||||
        computation->AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
            absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
 | 
			
		||||
    HloInstruction* padding = computation->AddInstruction(
 | 
			
		||||
        HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
 | 
			
		||||
    input = MakePadHlo(input, padding, padding_config).ValueOrDie();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -125,9 +124,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
 | 
			
		||||
 | 
			
		||||
  HloComputation* computation = kernel->parent();
 | 
			
		||||
  PrimitiveType element_type = kernel->shape().element_type();
 | 
			
		||||
  HloInstruction* padding =
 | 
			
		||||
      computation->AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
          absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
 | 
			
		||||
  HloInstruction* padding = computation->AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
 | 
			
		||||
  return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
 | 
			
		||||
}
 | 
			
		||||
}  // namespace
 | 
			
		||||
@ -236,9 +234,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
 | 
			
		||||
  // Create a new backward convolution replacing the old one.
 | 
			
		||||
  HloComputation* computation = backward_conv->parent();
 | 
			
		||||
  HloInstruction* output = backward_conv->mutable_operand(1);
 | 
			
		||||
  HloInstruction* padding = computation->AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(absl::make_unique<Literal>(
 | 
			
		||||
          LiteralUtil::Zero(input->shape().element_type()))));
 | 
			
		||||
  HloInstruction* padding =
 | 
			
		||||
      computation->AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
          LiteralUtil::Zero(input->shape().element_type())));
 | 
			
		||||
  HloInstruction* padded_input =
 | 
			
		||||
      MakePadHlo(input, padding, input_padding_config).ValueOrDie();
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {};
 | 
			
		||||
TEST_F(GpuCopyTest, UseMemcpy) {
 | 
			
		||||
  HloComputation::Builder builder(TestName());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal =
 | 
			
		||||
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
 | 
			
		||||
  Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
 | 
			
		||||
  HloInstruction* constant = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(std::move(literal)));
 | 
			
		||||
  builder.AddInstruction(HloInstruction::CreateUnary(
 | 
			
		||||
 | 
			
		||||
@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedR0Bool) {
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedR1U32) {
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedR2F32) {
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedR3F32) {
 | 
			
		||||
  TestInfeedRoundTrip(
 | 
			
		||||
      *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
 | 
			
		||||
                              {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
 | 
			
		||||
      LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
 | 
			
		||||
                             {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
 | 
			
		||||
  const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
 | 
			
		||||
  const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
 | 
			
		||||
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
 | 
			
		||||
      {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
 | 
			
		||||
       {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
 | 
			
		||||
      r3_dim0minor));
 | 
			
		||||
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
 | 
			
		||||
      {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
 | 
			
		||||
       {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
 | 
			
		||||
      r3_dim0major));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedR4S32) {
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR4(
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR4(
 | 
			
		||||
      {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
 | 
			
		||||
       {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
 | 
			
		||||
}
 | 
			
		||||
@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) {
 | 
			
		||||
TEST_F(InfeedTest, LargeInfeed) {
 | 
			
		||||
  Array4D<float> array(80, 100, 8, 128);
 | 
			
		||||
  array.FillIota(1.0f);
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D<float>(array));
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D<float>(array));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedTuple) {
 | 
			
		||||
  TestInfeedRoundTrip(
 | 
			
		||||
      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
 | 
			
		||||
                               LiteralUtil::CreateR0<bool>(false).get()}));
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
 | 
			
		||||
       LiteralUtil::CreateR0<bool>(false)}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests that a large tuple infeed can be handled.
 | 
			
		||||
TEST_F(InfeedTest, SingleInfeedLargeTuple) {
 | 
			
		||||
  Array4D<float> array(40, 100, 8, 128);
 | 
			
		||||
  array.FillIota(1.0f);
 | 
			
		||||
  TestInfeedRoundTrip(*LiteralUtil::MakeTuple(
 | 
			
		||||
      {LiteralUtil::CreateR4FromArray4D<float>(array).get(),
 | 
			
		||||
       LiteralUtil::CreateR0<int32>(5).get()}));
 | 
			
		||||
  TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {LiteralUtil::CreateR4FromArray4D<float>(array),
 | 
			
		||||
       LiteralUtil::CreateR0<int32>(5)}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -76,10 +76,10 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
 | 
			
		||||
        continue;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
 | 
			
		||||
      Literal result;
 | 
			
		||||
      // Currently we skip unimplemented operations.
 | 
			
		||||
      // TODO(b/35975797): Fold constant computations for more operations.
 | 
			
		||||
      if (result == nullptr) {
 | 
			
		||||
      if (!evaluator->TryEvaluate(instruction, &result)) {
 | 
			
		||||
        VLOG(2) << "Constant folding failed for instruction: "
 | 
			
		||||
                << instruction->ToString();
 | 
			
		||||
        continue;
 | 
			
		||||
 | 
			
		||||
@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(auto literal,
 | 
			
		||||
                          LiteralUtil::CreateRandomLiteral<F32>(
 | 
			
		||||
                              ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
 | 
			
		||||
  auto literal_clone = literal->Literal::CloneToUnique();
 | 
			
		||||
  auto literal_clone = literal.Clone();
 | 
			
		||||
  HloInstruction* literal_instruction = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(std::move(literal)));
 | 
			
		||||
  Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
 | 
			
		||||
@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
 | 
			
		||||
  root->literal().EachCell<NativeT>(
 | 
			
		||||
      [&](absl::Span<const int64> indices, NativeT value) {
 | 
			
		||||
        std::vector<int64> rindexes = Permute(permutation, indices);
 | 
			
		||||
        matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
 | 
			
		||||
        matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
 | 
			
		||||
      });
 | 
			
		||||
  EXPECT_TRUE(matched);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -321,18 +321,17 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
 | 
			
		||||
  padding_config_dim.set_edge_padding_high(zeros_to_append);
 | 
			
		||||
  *padding_config.add_dimensions() = padding_config_dim;
 | 
			
		||||
 | 
			
		||||
  HloInstruction* zero = computation->AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(absl::make_unique<Literal>(
 | 
			
		||||
          LiteralUtil::Zero(operand->shape().element_type()))));
 | 
			
		||||
  HloInstruction* zero =
 | 
			
		||||
      computation->AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
          LiteralUtil::Zero(operand->shape().element_type())));
 | 
			
		||||
  return MakePadHlo(operand, zero, padding_config);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<HloInstruction*> BroadcastZeros(
 | 
			
		||||
    HloComputation* computation, PrimitiveType element_type,
 | 
			
		||||
    absl::Span<const int64> broadcast_dimensions) {
 | 
			
		||||
  HloInstruction* zero =
 | 
			
		||||
      computation->AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
          absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
 | 
			
		||||
  HloInstruction* zero = computation->AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
 | 
			
		||||
  return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
 | 
			
		||||
                          /*result_shape_bounds=*/broadcast_dimensions);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -57,10 +57,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
 | 
			
		||||
  entry_computation->set_root_instruction(first_1_dims_collapsed);
 | 
			
		||||
 | 
			
		||||
  HloEvaluator evaluator;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
 | 
			
		||||
                          evaluator.Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
 | 
			
		||||
                          evaluator.Evaluate<Literal>(
 | 
			
		||||
                              *module, {LiteralUtil::CreateR1<int32>({3, 4})}));
 | 
			
		||||
  CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({3, 4}));
 | 
			
		||||
  CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({3, 4}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
 | 
			
		||||
@ -78,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
 | 
			
		||||
 | 
			
		||||
  HloEvaluator evaluator;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      std::unique_ptr<Literal> result_literal,
 | 
			
		||||
      evaluator.Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
      Literal result_literal,
 | 
			
		||||
      evaluator.Evaluate<Literal>(
 | 
			
		||||
          *module,
 | 
			
		||||
          {LiteralUtil::CreateR3<int32>(
 | 
			
		||||
              {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})}));
 | 
			
		||||
  CHECK_EQ(*result_literal,
 | 
			
		||||
           *LiteralUtil::CreateR2<int32>(
 | 
			
		||||
  CHECK_EQ(result_literal,
 | 
			
		||||
           LiteralUtil::CreateR2<int32>(
 | 
			
		||||
               {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -103,10 +103,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
 | 
			
		||||
 | 
			
		||||
  HloEvaluator evaluator;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      std::unique_ptr<Literal> result_literal,
 | 
			
		||||
      evaluator.Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
          *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
 | 
			
		||||
  CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9, 10}}));
 | 
			
		||||
      Literal result_literal,
 | 
			
		||||
      evaluator.Evaluate<Literal>(*module,
 | 
			
		||||
                                  {LiteralUtil::CreateR1<int32>({9, 10})}));
 | 
			
		||||
  CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9, 10}}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
 | 
			
		||||
@ -124,10 +124,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
 | 
			
		||||
 | 
			
		||||
  HloEvaluator evaluator;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      std::unique_ptr<Literal> result_literal,
 | 
			
		||||
      evaluator.Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
          *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
 | 
			
		||||
  CHECK_EQ(*result_literal, *LiteralUtil::CreateR3<int32>({{{9, 10}}}));
 | 
			
		||||
      Literal result_literal,
 | 
			
		||||
      evaluator.Evaluate<Literal>(*module,
 | 
			
		||||
                                  {LiteralUtil::CreateR1<int32>({9, 10})}));
 | 
			
		||||
  CHECK_EQ(result_literal, LiteralUtil::CreateR3<int32>({{{9, 10}}}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
 | 
			
		||||
@ -144,10 +144,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
 | 
			
		||||
  entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
 | 
			
		||||
 | 
			
		||||
  HloEvaluator evaluator;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
 | 
			
		||||
                          evaluator.Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
                              *module, {LiteralUtil::CreateR0<int32>(9)}));
 | 
			
		||||
  CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9}}));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      Literal result_literal,
 | 
			
		||||
      evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(9)}));
 | 
			
		||||
  CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9}}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
 | 
			
		||||
@ -165,11 +165,11 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
 | 
			
		||||
 | 
			
		||||
  HloEvaluator evaluator;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      std::unique_ptr<Literal> result_literal,
 | 
			
		||||
      evaluator.Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
      Literal result_literal,
 | 
			
		||||
      evaluator.Evaluate<Literal>(
 | 
			
		||||
          *module, {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
 | 
			
		||||
  CHECK_EQ(*result_literal,
 | 
			
		||||
           *LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
 | 
			
		||||
  CHECK_EQ(result_literal,
 | 
			
		||||
           LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
 | 
			
		||||
@ -187,10 +187,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
 | 
			
		||||
  entry_computation->set_root_instruction(zero_padded_param);
 | 
			
		||||
 | 
			
		||||
  HloEvaluator evaluator;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
 | 
			
		||||
                          evaluator.Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
 | 
			
		||||
                          evaluator.Evaluate<Literal>(
 | 
			
		||||
                              *module, {LiteralUtil::CreateR1<int32>({3, 4})}));
 | 
			
		||||
  CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
 | 
			
		||||
  CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
 | 
			
		||||
@ -208,10 +208,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
 | 
			
		||||
  entry_computation->set_root_instruction(zeros);
 | 
			
		||||
 | 
			
		||||
  HloEvaluator evaluator;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
 | 
			
		||||
                          evaluator.Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
                              *module, {LiteralUtil::CreateR0<int32>(0)}));
 | 
			
		||||
  CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      Literal result_literal,
 | 
			
		||||
      evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(0)}));
 | 
			
		||||
  CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
 | 
			
		||||
@ -229,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
 | 
			
		||||
  entry_computation->set_root_instruction(zeros);
 | 
			
		||||
 | 
			
		||||
  HloEvaluator evaluator;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
 | 
			
		||||
                          evaluator.Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
 | 
			
		||||
                          evaluator.Evaluate<Literal>(
 | 
			
		||||
                              *module, {LiteralUtil::CreateR0<float>(0.0f)}));
 | 
			
		||||
  CHECK_EQ(*result_literal,
 | 
			
		||||
           *LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
 | 
			
		||||
  CHECK_EQ(result_literal,
 | 
			
		||||
           LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -73,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
 | 
			
		||||
 | 
			
		||||
  auto result = ExecuteAndTransfer(module->Clone(), {});
 | 
			
		||||
  auto expected = LiteralUtil::CreateR0<float>(84.0);
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
 | 
			
		||||
@ -105,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
 | 
			
		||||
 | 
			
		||||
  auto result = ExecuteAndTransfer(module->Clone(), {});
 | 
			
		||||
  auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
 | 
			
		||||
@ -135,7 +135,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
 | 
			
		||||
 | 
			
		||||
  auto result = ExecuteAndTransfer(module->Clone(), {});
 | 
			
		||||
  auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
 | 
			
		||||
 | 
			
		||||
@ -54,9 +54,8 @@ namespace xla {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
template <typename OperandT>
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
 | 
			
		||||
                                           LiteralSlice lhs_literal,
 | 
			
		||||
                                           LiteralSlice rhs_literal) {
 | 
			
		||||
StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
 | 
			
		||||
                          LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
 | 
			
		||||
  std::function<bool(OperandT, OperandT)> compare_op;
 | 
			
		||||
  switch (opcode) {
 | 
			
		||||
    case HloOpcode::kEq:
 | 
			
		||||
@ -94,9 +93,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
 | 
			
		||||
                 << HloOpcodeString(opcode);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto result = absl::make_unique<Literal>(shape);
 | 
			
		||||
  Literal result(shape);
 | 
			
		||||
  TF_RETURN_IF_ERROR(
 | 
			
		||||
      result->Populate<bool>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
        return compare_op(lhs_literal.Get<OperandT>(multi_index),
 | 
			
		||||
                          rhs_literal.Get<OperandT>(multi_index));
 | 
			
		||||
      }));
 | 
			
		||||
@ -105,9 +104,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
 | 
			
		||||
    const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal,
 | 
			
		||||
    LiteralSlice rhs_literal) {
 | 
			
		||||
StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
 | 
			
		||||
                                     LiteralSlice lhs_literal,
 | 
			
		||||
                                     LiteralSlice rhs_literal) {
 | 
			
		||||
  std::function<bool(complex64, complex64)> compare_op;
 | 
			
		||||
  switch (opcode) {
 | 
			
		||||
    case HloOpcode::kEq:
 | 
			
		||||
@ -125,9 +124,9 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
 | 
			
		||||
                 << HloOpcodeString(opcode);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto result = absl::make_unique<Literal>(shape);
 | 
			
		||||
  Literal result(shape);
 | 
			
		||||
  TF_RETURN_IF_ERROR(
 | 
			
		||||
      result->Populate<bool>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
        return compare_op(lhs_literal.Get<complex64>(multi_index),
 | 
			
		||||
                          rhs_literal.Get<complex64>(multi_index));
 | 
			
		||||
      }));
 | 
			
		||||
@ -193,7 +192,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename LiteralPtr>
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
 | 
			
		||||
StatusOr<Literal> HloEvaluator::Evaluate(
 | 
			
		||||
    const HloModule& module, absl::Span<const LiteralPtr> arg_literals) {
 | 
			
		||||
  XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
 | 
			
		||||
 | 
			
		||||
@ -206,11 +205,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
 | 
			
		||||
  TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
 | 
			
		||||
 | 
			
		||||
  return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
 | 
			
		||||
      .CloneToUnique();
 | 
			
		||||
      .Clone();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
 | 
			
		||||
    const HloModule& module, absl::Span<const Literal> arg_literals) {
 | 
			
		||||
  std::vector<const Literal*> arg_literal_ptrs;
 | 
			
		||||
  for (const auto& literal_ptr : arg_literals) {
 | 
			
		||||
    arg_literal_ptrs.push_back(&literal_ptr);
 | 
			
		||||
  }
 | 
			
		||||
  return Evaluate<const Literal*>(module, arg_literal_ptrs);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename LiteralPtr>
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
 | 
			
		||||
StatusOr<Literal> HloEvaluator::Evaluate(
 | 
			
		||||
    const HloComputation& computation,
 | 
			
		||||
    absl::Span<const LiteralPtr> arg_literals) {
 | 
			
		||||
  CHECK(computation.parent() != nullptr);
 | 
			
		||||
@ -224,11 +233,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TF_RETURN_IF_ERROR(computation.Accept(this));
 | 
			
		||||
  return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
 | 
			
		||||
  return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
 | 
			
		||||
    const HloComputation& computation, absl::Span<const Literal> arg_literals) {
 | 
			
		||||
  std::vector<const Literal*> arg_literal_ptrs;
 | 
			
		||||
  for (const auto& literal_ptr : arg_literals) {
 | 
			
		||||
    arg_literal_ptrs.push_back(&literal_ptr);
 | 
			
		||||
  }
 | 
			
		||||
  return Evaluate<const Literal*>(computation, arg_literal_ptrs);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename LiteralPtr>
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
 | 
			
		||||
StatusOr<Literal> HloEvaluator::Evaluate(
 | 
			
		||||
    HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals) {
 | 
			
		||||
  TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
 | 
			
		||||
 | 
			
		||||
@ -247,18 +266,27 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
 | 
			
		||||
              << input_literal->ToString();
 | 
			
		||||
      TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
 | 
			
		||||
 | 
			
		||||
      evaluated_[operand] = input_literal->CloneToUnique();
 | 
			
		||||
      evaluated_[operand] = input_literal->Clone();
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TF_RETURN_IF_ERROR(Preprocess(instruction));
 | 
			
		||||
  TF_RETURN_IF_ERROR(instruction->Visit(this));
 | 
			
		||||
  TF_RETURN_IF_ERROR(Postprocess(instruction));
 | 
			
		||||
  return GetEvaluatedLiteralFor(instruction).CloneToUnique();
 | 
			
		||||
  return GetEvaluatedLiteralFor(instruction).Clone();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
 | 
			
		||||
    HloInstruction* instruction) {
 | 
			
		||||
template <>
 | 
			
		||||
StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
 | 
			
		||||
    HloInstruction* instruction, absl::Span<const Literal> arg_literals) {
 | 
			
		||||
  std::vector<const Literal*> arg_literal_ptrs;
 | 
			
		||||
  for (const auto& literal : arg_literals) {
 | 
			
		||||
    arg_literal_ptrs.push_back(&literal);
 | 
			
		||||
  }
 | 
			
		||||
  return Evaluate<const Literal*>(instruction, arg_literal_ptrs);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
 | 
			
		||||
  if (instruction->opcode() == HloOpcode::kParameter) {
 | 
			
		||||
    return tensorflow::errors::FailedPrecondition(
 | 
			
		||||
        "Cannot evaluate a parameter.");
 | 
			
		||||
@ -274,21 +302,22 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
 | 
			
		||||
  TF_RETURN_IF_ERROR(Preprocess(instruction));
 | 
			
		||||
  TF_RETURN_IF_ERROR(instruction->Visit(this));
 | 
			
		||||
  TF_RETURN_IF_ERROR(Postprocess(instruction));
 | 
			
		||||
  return GetEvaluatedLiteralFor(instruction).CloneToUnique();
 | 
			
		||||
  return GetEvaluatedLiteralFor(instruction).Clone();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
 | 
			
		||||
    HloInstruction* instruction) {
 | 
			
		||||
bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
 | 
			
		||||
  CHECK(result != nullptr);
 | 
			
		||||
  auto result_or = Evaluate(instruction);
 | 
			
		||||
  if (!result_or.ok()) {
 | 
			
		||||
    VLOG(1) << "TryEvaluate failed:" << result_or.status();
 | 
			
		||||
    return nullptr;
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return result_or.ConsumeValueOrDie();
 | 
			
		||||
  *result = result_or.ConsumeValueOrDie();
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
 | 
			
		||||
StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
 | 
			
		||||
    const HloInstruction* instruction,
 | 
			
		||||
    const std::unordered_map<const HloInstruction*, const Literal*>&
 | 
			
		||||
        substitutions) {
 | 
			
		||||
@ -299,7 +328,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
 | 
			
		||||
      owned_operands.push_back(operand->Clone());
 | 
			
		||||
    } else {
 | 
			
		||||
      owned_operands.push_back(
 | 
			
		||||
          HloInstruction::CreateConstant(it->second->CloneToUnique()));
 | 
			
		||||
          HloInstruction::CreateConstant(it->second->Clone()));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -316,12 +345,12 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
 | 
			
		||||
StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
 | 
			
		||||
    HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
 | 
			
		||||
  std::unique_ptr<HloInstruction> lhs_instr =
 | 
			
		||||
      HloInstruction::CreateConstant(lhs.CloneToUnique());
 | 
			
		||||
      HloInstruction::CreateConstant(lhs.Clone());
 | 
			
		||||
  std::unique_ptr<HloInstruction> rhs_instr =
 | 
			
		||||
      HloInstruction::CreateConstant(rhs.CloneToUnique());
 | 
			
		||||
      HloInstruction::CreateConstant(rhs.Clone());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<HloInstruction> cloned_instruction =
 | 
			
		||||
      HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
 | 
			
		||||
@ -331,10 +360,10 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
 | 
			
		||||
StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
 | 
			
		||||
    HloOpcode opcode, const Literal& operand) {
 | 
			
		||||
  std::unique_ptr<HloInstruction> operand_instr =
 | 
			
		||||
      HloInstruction::CreateConstant(operand.CloneToUnique());
 | 
			
		||||
      HloInstruction::CreateConstant(operand.Clone());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<HloInstruction> cloned_instruction =
 | 
			
		||||
      HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
 | 
			
		||||
@ -343,14 +372,14 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
 | 
			
		||||
StatusOr<Literal> HloEvaluator::EvaluateDotOp(
 | 
			
		||||
    const DotDimensionNumbers& dim_numbers,
 | 
			
		||||
    const PrecisionConfig& precision_config, const Literal& lhs,
 | 
			
		||||
    const Literal& rhs) {
 | 
			
		||||
  std::unique_ptr<HloInstruction> lhs_instr =
 | 
			
		||||
      HloInstruction::CreateConstant(lhs.CloneToUnique());
 | 
			
		||||
      HloInstruction::CreateConstant(lhs.Clone());
 | 
			
		||||
  std::unique_ptr<HloInstruction> rhs_instr =
 | 
			
		||||
      HloInstruction::CreateConstant(rhs.CloneToUnique());
 | 
			
		||||
      HloInstruction::CreateConstant(rhs.Clone());
 | 
			
		||||
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(
 | 
			
		||||
      Shape dot_shape,
 | 
			
		||||
@ -371,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
 | 
			
		||||
      << ", but input literal shape is: "
 | 
			
		||||
      << ShapeUtil::HumanString(input_literal->shape());
 | 
			
		||||
 | 
			
		||||
  evaluated_[parameter] = input_literal->CloneToUnique();
 | 
			
		||||
  evaluated_[parameter] = input_literal->Clone();
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -421,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
 | 
			
		||||
 | 
			
		||||
  for (auto operand : operands) {
 | 
			
		||||
    const Shape& operand_shape = operand->shape();
 | 
			
		||||
    TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
 | 
			
		||||
    TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
 | 
			
		||||
        GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
 | 
			
		||||
        AsInt64Slice(operand_shape.dimensions())));
 | 
			
		||||
    dest_indices[concat_dim] +=
 | 
			
		||||
@ -824,7 +853,7 @@ class OutputOffsetIndexToInputIndex {
 | 
			
		||||
// there is one) to `reshaped_start_indices`.
 | 
			
		||||
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
 | 
			
		||||
    int64 index_vector_dim, const Literal& start_indices,
 | 
			
		||||
    std::unique_ptr<Literal>* reshaped_start_indices) {
 | 
			
		||||
    Literal* reshaped_start_indices) {
 | 
			
		||||
  if (start_indices.shape().dimensions_size() != index_vector_dim) {
 | 
			
		||||
    return std::cref(start_indices);
 | 
			
		||||
  }
 | 
			
		||||
@ -834,16 +863,16 @@ static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
 | 
			
		||||
  new_shape.push_back(1);
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
 | 
			
		||||
                      start_indices.Reshape(new_shape));
 | 
			
		||||
  return std::cref(**reshaped_start_indices);
 | 
			
		||||
  return std::cref(*reshaped_start_indices);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status HloEvaluator::HandleGather(HloInstruction* gather) {
 | 
			
		||||
  std::unique_ptr<Literal> result = Literal::CreateFromShape(gather->shape());
 | 
			
		||||
  Literal result = Literal::CreateFromShape(gather->shape());
 | 
			
		||||
  const Shape& shape = gather->shape();
 | 
			
		||||
  const GatherDimensionNumbers& dim_numbers =
 | 
			
		||||
      gather->gather_dimension_numbers();
 | 
			
		||||
  const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
 | 
			
		||||
  std::unique_ptr<Literal> reshaped_start_indices;
 | 
			
		||||
  Literal reshaped_start_indices;
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(
 | 
			
		||||
      const Literal& start_indices,
 | 
			
		||||
      ReshapedGatherIndices(dim_numbers.index_vector_dim(),
 | 
			
		||||
@ -908,7 +937,7 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
 | 
			
		||||
      DCHECK_LT(input_index[i], operand_shape.dimensions(i));
 | 
			
		||||
    }
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        result->CopyElementFrom(operand, input_index, output_index));
 | 
			
		||||
        result.CopyElementFrom(operand, input_index, output_index));
 | 
			
		||||
    return true;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
@ -977,18 +1006,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
 | 
			
		||||
 | 
			
		||||
  const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
 | 
			
		||||
 | 
			
		||||
  evaluated_[get_tuple_element] = absl::make_unique<Literal>(
 | 
			
		||||
      ShapeUtil::GetTupleElementShape(operand->shape(), index));
 | 
			
		||||
  return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
 | 
			
		||||
                                                 /*dest_shape_index=*/{},
 | 
			
		||||
                                                 /*src_shape_index=*/{index});
 | 
			
		||||
  evaluated_[get_tuple_element] =
 | 
			
		||||
      Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
 | 
			
		||||
  return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
 | 
			
		||||
                                                /*dest_shape_index=*/{},
 | 
			
		||||
                                                /*src_shape_index=*/{index});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
 | 
			
		||||
  TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
 | 
			
		||||
 | 
			
		||||
  auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
 | 
			
		||||
  evaluated_[copy] = std::move(result);
 | 
			
		||||
  evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1004,7 +1031,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  HloEvaluator embedded_evaluator;
 | 
			
		||||
  std::unique_ptr<Literal> result =
 | 
			
		||||
  Literal result =
 | 
			
		||||
      embedded_evaluator.Evaluate<const Literal*>(*computation, arg_literals)
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
@ -1036,7 +1063,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  HloEvaluator embedded_evaluator;
 | 
			
		||||
  std::unique_ptr<Literal> result =
 | 
			
		||||
  Literal result =
 | 
			
		||||
      embedded_evaluator
 | 
			
		||||
          .Evaluate<const Literal*>(*readded_computation, arg_literals)
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
@ -1056,7 +1083,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
 | 
			
		||||
  auto* false_computation = conditional->false_computation();
 | 
			
		||||
 | 
			
		||||
  HloEvaluator embedded_evaluator;
 | 
			
		||||
  std::unique_ptr<Literal> result;
 | 
			
		||||
  Literal result;
 | 
			
		||||
  if (pred.Get<bool>({})) {
 | 
			
		||||
    result = embedded_evaluator
 | 
			
		||||
                 .Evaluate<const Literal*>(*true_computation,
 | 
			
		||||
@ -1081,9 +1108,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
 | 
			
		||||
  // If predicate is of scalar type, no element-wise selection would be needed.
 | 
			
		||||
  if (ShapeUtil::IsScalar(pred.shape())) {
 | 
			
		||||
    if (pred.Get<bool>({})) {
 | 
			
		||||
      evaluated_[select] = on_true.CloneToUnique();
 | 
			
		||||
      evaluated_[select] = on_true.Clone();
 | 
			
		||||
    } else {
 | 
			
		||||
      evaluated_[select] = on_false.CloneToUnique();
 | 
			
		||||
      evaluated_[select] = on_false.Clone();
 | 
			
		||||
    }
 | 
			
		||||
    return Status::OK();
 | 
			
		||||
  }
 | 
			
		||||
@ -1097,9 +1124,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
 | 
			
		||||
  const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
 | 
			
		||||
 | 
			
		||||
  if (pred.Get<bool>({})) {
 | 
			
		||||
    evaluated_[tuple_select] = on_true.CloneToUnique();
 | 
			
		||||
    evaluated_[tuple_select] = on_true.Clone();
 | 
			
		||||
  } else {
 | 
			
		||||
    evaluated_[tuple_select] = on_false.CloneToUnique();
 | 
			
		||||
    evaluated_[tuple_select] = on_false.Clone();
 | 
			
		||||
  }
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
@ -1108,7 +1135,7 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
 | 
			
		||||
  HloComputation* cond_comp = while_hlo->while_condition();
 | 
			
		||||
  HloComputation* body_comp = while_hlo->while_body();
 | 
			
		||||
  // Initialize the loop carried valued with the input to the While instruction.
 | 
			
		||||
  auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique();
 | 
			
		||||
  auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
 | 
			
		||||
  bool keep_going = true;
 | 
			
		||||
  int64 iteration_count = 0;
 | 
			
		||||
  HloEvaluator cond_evaluator(max_loop_iterations_);
 | 
			
		||||
@ -1118,13 +1145,13 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
 | 
			
		||||
      return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
 | 
			
		||||
                             while_hlo->name(), max_loop_iterations_);
 | 
			
		||||
    }
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate<Literal*>(
 | 
			
		||||
                                           *cond_comp, {lcv.get()}));
 | 
			
		||||
    keep_going = cond_val->GetFirstElement<bool>();
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(auto cond_val,
 | 
			
		||||
                        cond_evaluator.Evaluate<Literal*>(*cond_comp, {&lcv}));
 | 
			
		||||
    keep_going = cond_val.GetFirstElement<bool>();
 | 
			
		||||
    if (keep_going) {
 | 
			
		||||
      TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate<Literal*>(
 | 
			
		||||
                                             *body_comp, {lcv.get()}));
 | 
			
		||||
      VLOG(3) << "Loop iteration result: " << body_val->ToString();
 | 
			
		||||
                                             *body_comp, {&lcv}));
 | 
			
		||||
      VLOG(3) << "Loop iteration result: " << body_val.ToString();
 | 
			
		||||
      lcv = std::move(body_val);
 | 
			
		||||
      cond_evaluator.ResetVisitStates();
 | 
			
		||||
      loop_body_evaluator.ResetVisitStates();
 | 
			
		||||
@ -1139,9 +1166,9 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
 | 
			
		||||
// hoops to make this work.
 | 
			
		||||
namespace {
 | 
			
		||||
template <typename KeyType, typename ValueType>
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
 | 
			
		||||
    HloInstruction* sort, const Literal& keys_literal,
 | 
			
		||||
    const Literal& values_literal) {
 | 
			
		||||
StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
 | 
			
		||||
                                       const Literal& keys_literal,
 | 
			
		||||
                                       const Literal& values_literal) {
 | 
			
		||||
  auto rank = ShapeUtil::Rank(keys_literal.shape());
 | 
			
		||||
  TF_RET_CHECK(
 | 
			
		||||
      ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
 | 
			
		||||
@ -1179,57 +1206,55 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
 | 
			
		||||
      result_keys.push_back(key_value.first);
 | 
			
		||||
      result_values.push_back(key_value.second);
 | 
			
		||||
    }
 | 
			
		||||
    auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape());
 | 
			
		||||
    result_keys_literal->PopulateR1(absl::Span<const KeyType>(result_keys));
 | 
			
		||||
    auto result_values_literal =
 | 
			
		||||
        absl::make_unique<Literal>(values_literal.shape());
 | 
			
		||||
    result_values_literal->PopulateR1(
 | 
			
		||||
    Literal result_keys_literal(keys_literal.shape());
 | 
			
		||||
    result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
 | 
			
		||||
    Literal result_values_literal(values_literal.shape());
 | 
			
		||||
    result_values_literal.PopulateR1(
 | 
			
		||||
        absl::Span<const ValueType>(result_values));
 | 
			
		||||
    return std::make_pair(std::move(result_keys_literal),
 | 
			
		||||
                          std::move(result_values_literal));
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> result_tuple;
 | 
			
		||||
  Literal result_tuple;
 | 
			
		||||
  if (rank == 1) {
 | 
			
		||||
    auto result_pair = sort_r1(keys_literal, values_literal);
 | 
			
		||||
    result_tuple = LiteralUtil::MakeTuple(
 | 
			
		||||
        {result_pair.first.get(), result_pair.second.get()});
 | 
			
		||||
    result_tuple =
 | 
			
		||||
        LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
 | 
			
		||||
  } else {
 | 
			
		||||
    // For R2 sort, the desired semantics are to sort each matrix row
 | 
			
		||||
    // independently.
 | 
			
		||||
    auto keys_result_literal = absl::make_unique<Literal>(keys_literal.shape());
 | 
			
		||||
    auto values_result_literal =
 | 
			
		||||
        absl::make_unique<Literal>(values_literal.shape());
 | 
			
		||||
    Literal keys_result_literal(keys_literal.shape());
 | 
			
		||||
    Literal values_result_literal(values_literal.shape());
 | 
			
		||||
    int64 r1_length = keys_literal.shape().dimensions(1);
 | 
			
		||||
    for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
 | 
			
		||||
      TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
 | 
			
		||||
                          keys_literal.Slice({row, 0}, {row + 1, r1_length})
 | 
			
		||||
                              ->Reshape({r1_length}));
 | 
			
		||||
                              .Reshape({r1_length}));
 | 
			
		||||
      TF_ASSIGN_OR_RETURN(auto values_r1_slice,
 | 
			
		||||
                          values_literal.Slice({row, 0}, {row + 1, r1_length})
 | 
			
		||||
                              ->Reshape({r1_length}));
 | 
			
		||||
      auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice);
 | 
			
		||||
                              .Reshape({r1_length}));
 | 
			
		||||
      auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
 | 
			
		||||
      TF_ASSIGN_OR_RETURN(auto sorted_keys,
 | 
			
		||||
                          r1_result_pair.first->Reshape({1, r1_length}));
 | 
			
		||||
                          r1_result_pair.first.Reshape({1, r1_length}));
 | 
			
		||||
      TF_ASSIGN_OR_RETURN(auto sorted_values,
 | 
			
		||||
                          r1_result_pair.second->Reshape({1, r1_length}));
 | 
			
		||||
      TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom(
 | 
			
		||||
          *sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
 | 
			
		||||
      TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom(
 | 
			
		||||
          *sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
 | 
			
		||||
                          r1_result_pair.second.Reshape({1, r1_length}));
 | 
			
		||||
      TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
 | 
			
		||||
          sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
 | 
			
		||||
      TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
 | 
			
		||||
          sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
 | 
			
		||||
    }
 | 
			
		||||
    result_tuple = LiteralUtil::MakeTuple(
 | 
			
		||||
        {keys_result_literal.get(), values_result_literal.get()});
 | 
			
		||||
    result_tuple =
 | 
			
		||||
        LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
 | 
			
		||||
  VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
 | 
			
		||||
  return std::move(result_tuple);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename KeyType>
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
 | 
			
		||||
    HloInstruction* sort, const Literal& keys_literal,
 | 
			
		||||
    const Literal& values_literal) {
 | 
			
		||||
StatusOr<Literal> EvaluateSortCurried(HloInstruction* sort,
 | 
			
		||||
                                      const Literal& keys_literal,
 | 
			
		||||
                                      const Literal& values_literal) {
 | 
			
		||||
  switch (sort->operand(1)->shape().element_type()) {
 | 
			
		||||
    case F32:
 | 
			
		||||
      return EvaluateSortInternal<KeyType, float>(sort, keys_literal,
 | 
			
		||||
@ -1248,9 +1273,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> EvaluateSort(HloInstruction* sort,
 | 
			
		||||
                                                const Literal& keys_literal,
 | 
			
		||||
                                                const Literal& values_literal) {
 | 
			
		||||
StatusOr<Literal> EvaluateSort(HloInstruction* sort,
 | 
			
		||||
                               const Literal& keys_literal,
 | 
			
		||||
                               const Literal& values_literal) {
 | 
			
		||||
  switch (sort->operand(0)->shape().element_type()) {
 | 
			
		||||
    case F32:
 | 
			
		||||
      return EvaluateSortCurried<float>(sort, keys_literal, values_literal);
 | 
			
		||||
@ -1319,28 +1344,14 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) {
 | 
			
		||||
 | 
			
		||||
// Explicit instantiation of templatized Evaluate* methods.
 | 
			
		||||
//
 | 
			
		||||
template StatusOr<std::unique_ptr<Literal>>
 | 
			
		||||
HloEvaluator::Evaluate<const Literal*>(
 | 
			
		||||
template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
 | 
			
		||||
    const HloModule& module, absl::Span<const Literal* const> arg_literals);
 | 
			
		||||
template StatusOr<std::unique_ptr<Literal>>
 | 
			
		||||
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
    const HloModule& module,
 | 
			
		||||
    absl::Span<const std::unique_ptr<Literal>> arg_literals);
 | 
			
		||||
 | 
			
		||||
template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
 | 
			
		||||
    const Literal*>(const HloComputation& computation,
 | 
			
		||||
                    absl::Span<const Literal* const> arg_literals);
 | 
			
		||||
template StatusOr<std::unique_ptr<Literal>>
 | 
			
		||||
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
 | 
			
		||||
    const HloComputation& computation,
 | 
			
		||||
    absl::Span<const std::unique_ptr<Literal>> arg_literals);
 | 
			
		||||
    absl::Span<const Literal* const> arg_literals);
 | 
			
		||||
 | 
			
		||||
template StatusOr<std::unique_ptr<Literal>>
 | 
			
		||||
HloEvaluator::Evaluate<const Literal*>(
 | 
			
		||||
template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
 | 
			
		||||
    HloInstruction* instruction, absl::Span<const Literal* const> arg_literals);
 | 
			
		||||
template StatusOr<std::unique_ptr<Literal>>
 | 
			
		||||
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
    HloInstruction* instruction,
 | 
			
		||||
    absl::Span<const std::unique_ptr<Literal>> arg_literals);
 | 
			
		||||
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
 | 
			
		||||
@ -47,11 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
 | 
			
		||||
  // Precondition: The indices of arg_literals correspond to the parameter
 | 
			
		||||
  // numbers of the HLO parameters in the computation. See comment below for an
 | 
			
		||||
  // example.
 | 
			
		||||
  // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
 | 
			
		||||
  // `LiteralPtr` accepts either Literal or const Literal*
 | 
			
		||||
  // type.
 | 
			
		||||
  template <typename LiteralPtr>
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> Evaluate(
 | 
			
		||||
      const HloModule& module, absl::Span<const LiteralPtr> arg_literals);
 | 
			
		||||
  StatusOr<Literal> Evaluate(const HloModule& module,
 | 
			
		||||
                             absl::Span<const LiteralPtr> arg_literals);
 | 
			
		||||
 | 
			
		||||
  // Evaluates an HLO computation and an array of pointers to literals.
 | 
			
		||||
  // Returns the evaluated result as a literal if successful.
 | 
			
		||||
@ -69,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
 | 
			
		||||
  // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
 | 
			
		||||
  // 1 in this computation. The input literals array will then have its first
 | 
			
		||||
  // literal map to Parameter0 and the second map to Parameter1.
 | 
			
		||||
  // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
 | 
			
		||||
  // `LiteralPtr` accepts either Literal or const Literal*
 | 
			
		||||
  // type.
 | 
			
		||||
  template <typename LiteralPtr>
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> Evaluate(
 | 
			
		||||
      const HloComputation& computation,
 | 
			
		||||
      absl::Span<const LiteralPtr> arg_literals);
 | 
			
		||||
  StatusOr<Literal> Evaluate(const HloComputation& computation,
 | 
			
		||||
                             absl::Span<const LiteralPtr> arg_literals);
 | 
			
		||||
 | 
			
		||||
  // Evaluates a single HLO instruction and an array of pointers to literals.
 | 
			
		||||
  // Return the evaluated result as literal if successful.
 | 
			
		||||
@ -82,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
 | 
			
		||||
  // 1. argument literals correspond to the input instruction's parameters in
 | 
			
		||||
  // their post-ordering.
 | 
			
		||||
  // 2. the instruction's operands must be of either Parameter or Constant type.
 | 
			
		||||
  // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
 | 
			
		||||
  // `LiteralPtr` accepts either Literal or const Literal*
 | 
			
		||||
  // type.
 | 
			
		||||
  template <typename LiteralPtr>
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> Evaluate(
 | 
			
		||||
      HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals);
 | 
			
		||||
  StatusOr<Literal> Evaluate(HloInstruction* instruction,
 | 
			
		||||
                             absl::Span<const LiteralPtr> arg_literals);
 | 
			
		||||
 | 
			
		||||
  // Evaluates a single HLO instruction with constant operands.
 | 
			
		||||
  // Returns the evaluated result as literal if successful.
 | 
			
		||||
  // Precondition:
 | 
			
		||||
  // 1. all operands of the input instruction are constants.
 | 
			
		||||
  // 2. the instruction is not a Parameter operation.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
 | 
			
		||||
  StatusOr<Literal> Evaluate(HloInstruction* instruction);
 | 
			
		||||
 | 
			
		||||
  // Same as Evaluate, except returning nullptr on error.
 | 
			
		||||
  std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
 | 
			
		||||
  // Same as Evaluate, except returning false on error and accepts an output
 | 
			
		||||
  // pointer.
 | 
			
		||||
  bool TryEvaluate(HloInstruction* instruction, Literal* result);
 | 
			
		||||
 | 
			
		||||
  // Evaluates a single HLO instruction, substituting the given literals for
 | 
			
		||||
  // some of the instruction's operands.
 | 
			
		||||
  //
 | 
			
		||||
  // For example, given instruction = op(A, B, C) and the map
 | 
			
		||||
  // {A = x, C = y}, this evaluates op(x, B, y).
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> EvaluateWithSubstitutions(
 | 
			
		||||
  StatusOr<Literal> EvaluateWithSubstitutions(
 | 
			
		||||
      const HloInstruction* instruction,
 | 
			
		||||
      const std::unordered_map<const HloInstruction*, const Literal*>&
 | 
			
		||||
          substitutions);
 | 
			
		||||
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseBinaryOp(
 | 
			
		||||
      HloOpcode opcode, const Literal& lhs, const Literal& rhs);
 | 
			
		||||
  StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
 | 
			
		||||
                                                const Literal& lhs,
 | 
			
		||||
                                                const Literal& rhs);
 | 
			
		||||
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
 | 
			
		||||
      HloOpcode opcode, const Literal& operand);
 | 
			
		||||
  StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
 | 
			
		||||
                                               const Literal& operand);
 | 
			
		||||
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
 | 
			
		||||
      const DotDimensionNumbers& dim_numbers,
 | 
			
		||||
      const PrecisionConfig& precision_config, const Literal& lhs,
 | 
			
		||||
      const Literal& rhs);
 | 
			
		||||
  StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
 | 
			
		||||
                                  const PrecisionConfig& precision_config,
 | 
			
		||||
                                  const Literal& lhs, const Literal& rhs);
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  // Make HloEvaluatorTypedVisitor a friend because it is logically part of this
 | 
			
		||||
@ -197,7 +197,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
 | 
			
		||||
    auto it = evaluated_.find(hlo);
 | 
			
		||||
    CHECK(it != evaluated_.end())
 | 
			
		||||
        << "could not find evaluated value for: " << hlo->ToString();
 | 
			
		||||
    return *(it->second);
 | 
			
		||||
    return it->second;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Tracks the HLO instruction and its evaluated literal result.
 | 
			
		||||
@ -205,12 +205,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
 | 
			
		||||
  // that are no longer a parent for any other subsequent instruction in
 | 
			
		||||
  // post-orderring.
 | 
			
		||||
  // Must be cleared for each evaluation.
 | 
			
		||||
  tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
 | 
			
		||||
      evaluated_;
 | 
			
		||||
  // Storing Literal in place require the container to have pointer stability so
 | 
			
		||||
  // we cannot use FlatMap any more.
 | 
			
		||||
  std::unordered_map<const HloInstruction*, Literal> evaluated_;
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  template <typename ReturnT, typename NativeT>
 | 
			
		||||
  static StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
 | 
			
		||||
  static StatusOr<Literal> ElementWiseUnaryOpImpl(
 | 
			
		||||
      HloInstruction* instruction,
 | 
			
		||||
      const std::function<ReturnT(NativeT)>& unary_op,
 | 
			
		||||
      const Literal& operand_literal) {
 | 
			
		||||
@ -227,9 +228,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
 | 
			
		||||
          ShapeUtil::HumanString(operand->shape()));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    auto result = absl::make_unique<Literal>(shape);
 | 
			
		||||
    Literal result(shape);
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
          return unary_op(operand_literal.Get<NativeT>(multi_index));
 | 
			
		||||
        }));
 | 
			
		||||
    return std::move(result);
 | 
			
		||||
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -246,15 +246,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
  Status HandleConvert(HloInstruction* convert) override {
 | 
			
		||||
    const HloInstruction* operand = convert->operand(0);
 | 
			
		||||
    TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(Literal result,
 | 
			
		||||
                        parent_->GetEvaluatedLiteralFor(operand).Convert(
 | 
			
		||||
                            convert->shape().element_type()));
 | 
			
		||||
 | 
			
		||||
    if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
 | 
			
		||||
    if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
 | 
			
		||||
      parent_->evaluated_[convert] = std::move(result);
 | 
			
		||||
    } else {
 | 
			
		||||
      parent_->evaluated_[convert] =
 | 
			
		||||
          result->Relayout(convert->shape().layout());
 | 
			
		||||
      parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
 | 
			
		||||
    }
 | 
			
		||||
    return Status::OK();
 | 
			
		||||
  }
 | 
			
		||||
@ -262,15 +261,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
  Status HandleBitcastConvert(HloInstruction* convert) override {
 | 
			
		||||
    const HloInstruction* operand = convert->operand(0);
 | 
			
		||||
    TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(Literal result,
 | 
			
		||||
                        parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
 | 
			
		||||
                            convert->shape().element_type()));
 | 
			
		||||
 | 
			
		||||
    if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
 | 
			
		||||
    if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
 | 
			
		||||
      parent_->evaluated_[convert] = std::move(result);
 | 
			
		||||
    } else {
 | 
			
		||||
      parent_->evaluated_[convert] =
 | 
			
		||||
          result->Relayout(convert->shape().layout());
 | 
			
		||||
      parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
 | 
			
		||||
    }
 | 
			
		||||
    return Status::OK();
 | 
			
		||||
  }
 | 
			
		||||
@ -978,10 +976,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
        << ShapeUtil::HumanString(inferred_return_shape);
 | 
			
		||||
 | 
			
		||||
    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
 | 
			
		||||
    auto result = absl::make_unique<Literal>(result_shape);
 | 
			
		||||
    Literal result(result_shape);
 | 
			
		||||
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        result->Populate<ReturnT>([&](absl::Span<const int64> out_index) {
 | 
			
		||||
        result.Populate<ReturnT>([&](absl::Span<const int64> out_index) {
 | 
			
		||||
          std::vector<int64> from_index(out_index.begin(), out_index.end());
 | 
			
		||||
          for (const int64 dim : reverse_dimensions) {
 | 
			
		||||
            from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
 | 
			
		||||
@ -1157,8 +1155,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
      return static_cast<ReturnT>(result_val);
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    auto result = absl::make_unique<Literal>(result_shape);
 | 
			
		||||
    TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
 | 
			
		||||
    Literal result(result_shape);
 | 
			
		||||
    TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func));
 | 
			
		||||
 | 
			
		||||
    parent_->evaluated_[conv] = std::move(result);
 | 
			
		||||
    return Status::OK();
 | 
			
		||||
@ -1231,9 +1229,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    auto result = absl::make_unique<Literal>(dot->shape());
 | 
			
		||||
    Literal result(dot->shape());
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        result->Populate<ReturnT>([&](absl::Span<const int64> result_index) {
 | 
			
		||||
        result.Populate<ReturnT>([&](absl::Span<const int64> result_index) {
 | 
			
		||||
          ElementwiseT result_val = static_cast<ElementwiseT>(0);
 | 
			
		||||
 | 
			
		||||
          for (int64 i = 0; i < result_index.size(); i++) {
 | 
			
		||||
@ -1280,8 +1278,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    // Create new HLO of padded shape with padding value.
 | 
			
		||||
    ReturnT scalar =
 | 
			
		||||
        parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
 | 
			
		||||
    auto result = absl::make_unique<Literal>(pad->shape());
 | 
			
		||||
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
 | 
			
		||||
    Literal result(pad->shape());
 | 
			
		||||
    TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
 | 
			
		||||
        [&scalar](absl::Span<const int64> multi_index) { return scalar; }));
 | 
			
		||||
 | 
			
		||||
    const Literal& evaluated_operand =
 | 
			
		||||
@ -1289,7 +1287,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
 | 
			
		||||
    std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
 | 
			
		||||
                                   0);
 | 
			
		||||
    std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
 | 
			
		||||
    std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0);
 | 
			
		||||
 | 
			
		||||
    // Loop through each element of the operand, assign them to the
 | 
			
		||||
    // corresponding index of the resulting padded literal.
 | 
			
		||||
@ -1311,8 +1309,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
          return true;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      result->Set<ReturnT>(target_index,
 | 
			
		||||
                           evaluated_operand.Get<ReturnT>(input_index));
 | 
			
		||||
      result.Set<ReturnT>(target_index,
 | 
			
		||||
                          evaluated_operand.Get<ReturnT>(input_index));
 | 
			
		||||
      return true;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
@ -1439,16 +1437,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
 | 
			
		||||
  StatusOr<Literal> MapImpl(HloInstruction* map) {
 | 
			
		||||
    auto operands = map->operands();
 | 
			
		||||
    HloComputation* computation = map->to_apply();
 | 
			
		||||
 | 
			
		||||
    auto result = absl::make_unique<Literal>(map->shape());
 | 
			
		||||
    Literal result(map->shape());
 | 
			
		||||
 | 
			
		||||
    HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
          std::vector<std::unique_ptr<Literal>> arg_literals;
 | 
			
		||||
        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
          std::vector<Literal> arg_literals;
 | 
			
		||||
          arg_literals.reserve(operands.size());
 | 
			
		||||
 | 
			
		||||
          // Construct scalar literal parameters to be passed to the map
 | 
			
		||||
@ -1463,16 +1461,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
            arg_literals.push_back(std::move(curr_val_literal));
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          std::unique_ptr<Literal> computed_result =
 | 
			
		||||
              embedded_evaluator
 | 
			
		||||
                  .Evaluate<std::unique_ptr<Literal>>(*computation,
 | 
			
		||||
                                                      arg_literals)
 | 
			
		||||
          Literal computed_result =
 | 
			
		||||
              embedded_evaluator.Evaluate<Literal>(*computation, arg_literals)
 | 
			
		||||
                  .ConsumeValueOrDie();
 | 
			
		||||
          // Clear visit states so that the we can use the evaluate again on
 | 
			
		||||
          // the same computation.
 | 
			
		||||
          embedded_evaluator.ResetVisitStates();
 | 
			
		||||
 | 
			
		||||
          return computed_result->Get<ReturnT>({});
 | 
			
		||||
          return computed_result.Get<ReturnT>({});
 | 
			
		||||
        }));
 | 
			
		||||
    return std::move(result);
 | 
			
		||||
  }
 | 
			
		||||
@ -1557,9 +1553,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
                [](const ReturnT& a, const ReturnT& b) {
 | 
			
		||||
                  return SafeLess<ReturnT>(a, b);
 | 
			
		||||
                });
 | 
			
		||||
      auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
 | 
			
		||||
      result_literal->PopulateR1(absl::Span<const ReturnT>(result_data));
 | 
			
		||||
      VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
 | 
			
		||||
      Literal result_literal(keys_literal.shape());
 | 
			
		||||
      result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
 | 
			
		||||
      VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
 | 
			
		||||
      return result_literal;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
@ -1568,16 +1564,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    } else {
 | 
			
		||||
      // For R2 sort, the desired semantics are to sort each matrix row
 | 
			
		||||
      // independently.
 | 
			
		||||
      auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
 | 
			
		||||
      Literal result_literal(keys_literal.shape());
 | 
			
		||||
      int64 r1_length = keys->shape().dimensions(1);
 | 
			
		||||
      for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
 | 
			
		||||
        TF_ASSIGN_OR_RETURN(auto r1_slice,
 | 
			
		||||
                            keys_literal.Slice({row, 0}, {row + 1, r1_length})
 | 
			
		||||
                                ->Reshape({r1_length}));
 | 
			
		||||
        auto r1_result = sort_r1(*r1_slice);
 | 
			
		||||
        TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length}));
 | 
			
		||||
        TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
 | 
			
		||||
            *r1_result, {0, 0}, {row, 0}, {1, r1_length}));
 | 
			
		||||
                                .Reshape({r1_length}));
 | 
			
		||||
        auto r1_result = sort_r1(r1_slice);
 | 
			
		||||
        TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
 | 
			
		||||
        TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
 | 
			
		||||
            r1_result, {0, 0}, {row, 0}, {1, r1_length}));
 | 
			
		||||
      }
 | 
			
		||||
      parent_->evaluated_[sort] = std::move(result_literal);
 | 
			
		||||
    }
 | 
			
		||||
@ -1651,9 +1647,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
 | 
			
		||||
    absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args);
 | 
			
		||||
    absl::InlinedVector<Literal, 1> results(num_args);
 | 
			
		||||
    for (int64 i = 0; i < num_args; ++i) {
 | 
			
		||||
      results[i] = absl::make_unique<Literal>(result_shape);
 | 
			
		||||
      results[i] = Literal(result_shape);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Status eval_status;
 | 
			
		||||
@ -1667,7 +1663,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for (int64 input = 0; input < num_args; ++input) {
 | 
			
		||||
      TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>(
 | 
			
		||||
      TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>(
 | 
			
		||||
          [&](absl::Span<const int64> multi_index) {
 | 
			
		||||
            if (!eval_status.ok()) {
 | 
			
		||||
              return init_scalars[input];
 | 
			
		||||
@ -1703,8 +1699,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
              }
 | 
			
		||||
 | 
			
		||||
              // Evaluate computation with specified literal operands.
 | 
			
		||||
              absl::InlinedVector<std::unique_ptr<Literal>, 1>
 | 
			
		||||
                  embedded_operands;
 | 
			
		||||
              absl::InlinedVector<Literal, 1> embedded_operands;
 | 
			
		||||
              for (ReturnT value : result_values) {
 | 
			
		||||
                embedded_operands.push_back(
 | 
			
		||||
                    LiteralUtil::CreateR0<ReturnT>(value));
 | 
			
		||||
@ -1717,11 +1712,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
                  embedded_operands.size());
 | 
			
		||||
              std::transform(embedded_operands.begin(), embedded_operands.end(),
 | 
			
		||||
                             embedded_operands_ptrs.begin(),
 | 
			
		||||
                             [](const std::unique_ptr<Literal>& ptr) {
 | 
			
		||||
                               return ptr.get();
 | 
			
		||||
                             });
 | 
			
		||||
                             [](Literal& literal) { return &literal; });
 | 
			
		||||
 | 
			
		||||
              TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
 | 
			
		||||
              TF_ASSIGN_OR_RETURN(Literal computed_result,
 | 
			
		||||
                                  embedded_evaluator.Evaluate<const Literal*>(
 | 
			
		||||
                                      *function, embedded_operands_ptrs));
 | 
			
		||||
              // Clear visit states so that we can use the evaluator again on
 | 
			
		||||
@ -1729,10 +1722,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
              embedded_evaluator.ResetVisitStates();
 | 
			
		||||
              // Assign computed result to result_val.
 | 
			
		||||
              if (!has_tuple_output) {
 | 
			
		||||
                result_values[0] = computed_result->Get<ReturnT>({});
 | 
			
		||||
                result_values[0] = computed_result.Get<ReturnT>({});
 | 
			
		||||
              } else {
 | 
			
		||||
                for (int64 i = 0; i < num_args; ++i) {
 | 
			
		||||
                  result_values[i] = computed_result->Get<ReturnT>(
 | 
			
		||||
                  result_values[i] = computed_result.Get<ReturnT>(
 | 
			
		||||
                      /*multi_index=*/{}, /*shape_index=*/{i});
 | 
			
		||||
                }
 | 
			
		||||
              }
 | 
			
		||||
@ -1748,9 +1741,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    if (!has_tuple_output) {
 | 
			
		||||
      parent_->evaluated_[reduce] = std::move(results[0]);
 | 
			
		||||
    } else {
 | 
			
		||||
      auto tuple_result = absl::make_unique<Literal>(reduce->shape());
 | 
			
		||||
      Literal tuple_result(reduce->shape());
 | 
			
		||||
      for (int64 i = 0; i < num_args; ++i) {
 | 
			
		||||
        TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i}));
 | 
			
		||||
        TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
 | 
			
		||||
      }
 | 
			
		||||
      parent_->evaluated_[reduce] = std::move(tuple_result);
 | 
			
		||||
    }
 | 
			
		||||
@ -1781,10 +1774,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
 | 
			
		||||
    auto init_scalar = init_literal.Get<ReturnT>({});
 | 
			
		||||
 | 
			
		||||
    auto result = absl::make_unique<Literal>(select_and_scatter->shape());
 | 
			
		||||
    Literal result(select_and_scatter->shape());
 | 
			
		||||
 | 
			
		||||
    // Initialize result array with the init value.
 | 
			
		||||
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
 | 
			
		||||
    TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
 | 
			
		||||
        [&](absl::Span<const int64> output_index) { return init_scalar; }));
 | 
			
		||||
 | 
			
		||||
    std::vector<int64> window_dimension_sizes;
 | 
			
		||||
@ -1834,15 +1827,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
              selected_val = curr_val;
 | 
			
		||||
              selected_index = operand_index;
 | 
			
		||||
            }
 | 
			
		||||
            curr_val_literal->Set({}, curr_val);
 | 
			
		||||
            selected_val_literal->Set({}, *selected_val);
 | 
			
		||||
            std::unique_ptr<Literal> computed_result =
 | 
			
		||||
            curr_val_literal.Set({}, curr_val);
 | 
			
		||||
            selected_val_literal.Set({}, *selected_val);
 | 
			
		||||
            Literal computed_result =
 | 
			
		||||
                embedded_evaluator
 | 
			
		||||
                    .Evaluate<const Literal*>(
 | 
			
		||||
                        *select,
 | 
			
		||||
                        {selected_val_literal.get(), curr_val_literal.get()})
 | 
			
		||||
                        *select, {&selected_val_literal, &curr_val_literal})
 | 
			
		||||
                    .ConsumeValueOrDie();
 | 
			
		||||
            bool selected = !computed_result->Get<bool>({});
 | 
			
		||||
            bool selected = !computed_result.Get<bool>({});
 | 
			
		||||
            if (selected) {
 | 
			
		||||
              selected_val = curr_val;
 | 
			
		||||
              selected_index = operand_index;
 | 
			
		||||
@ -1856,16 +1848,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
            if (std::equal(operand_index.begin(), operand_index.end(),
 | 
			
		||||
                           selected_index->begin())) {
 | 
			
		||||
              auto source = source_literal.Get<ReturnT>(source_index);
 | 
			
		||||
              auto scattered = result->Get<ReturnT>(operand_index);
 | 
			
		||||
              source_literal_scatter->Set({}, source);
 | 
			
		||||
              scattered_literal->Set({}, scattered);
 | 
			
		||||
              std::unique_ptr<Literal> computed_result =
 | 
			
		||||
              auto scattered = result.Get<ReturnT>(operand_index);
 | 
			
		||||
              source_literal_scatter.Set({}, source);
 | 
			
		||||
              scattered_literal.Set({}, scattered);
 | 
			
		||||
              Literal computed_result =
 | 
			
		||||
                  embedded_evaluator
 | 
			
		||||
                      .Evaluate<const Literal*>(*scatter,
 | 
			
		||||
                                                {source_literal_scatter.get(),
 | 
			
		||||
                                                 scattered_literal.get()})
 | 
			
		||||
                      .Evaluate<const Literal*>(
 | 
			
		||||
                          *scatter,
 | 
			
		||||
                          {&source_literal_scatter, &scattered_literal})
 | 
			
		||||
                      .ConsumeValueOrDie();
 | 
			
		||||
              result->Set(operand_index, computed_result->Get<ReturnT>({}));
 | 
			
		||||
              result.Set(operand_index, computed_result.Get<ReturnT>({}));
 | 
			
		||||
              // Clear visit states so that the we can use the evaluator again
 | 
			
		||||
              // on the same computation.
 | 
			
		||||
              embedded_evaluator.ResetVisitStates();
 | 
			
		||||
@ -1916,10 +1908,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
 | 
			
		||||
 | 
			
		||||
    HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
 | 
			
		||||
    auto result = absl::make_unique<Literal>(reduce_window->shape());
 | 
			
		||||
    Literal result(reduce_window->shape());
 | 
			
		||||
    // For each resulting dimension, calculate and assign computed value.
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        result->Populate<ReturnT>([&](absl::Span<const int64> output_index) {
 | 
			
		||||
        result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
 | 
			
		||||
          ReturnT result_val = init_scalar;
 | 
			
		||||
 | 
			
		||||
          std::fill(window_index.begin(), window_index.end(), 0);
 | 
			
		||||
@ -1935,18 +1927,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
                    LiteralUtil::CreateR0<ReturnT>(curr_val);
 | 
			
		||||
                const auto result_val_literal =
 | 
			
		||||
                    LiteralUtil::CreateR0<ReturnT>(result_val);
 | 
			
		||||
                std::unique_ptr<Literal> computed_result =
 | 
			
		||||
                Literal computed_result =
 | 
			
		||||
                    embedded_evaluator
 | 
			
		||||
                        .Evaluate<const Literal*>(
 | 
			
		||||
                            *function,
 | 
			
		||||
                            {result_val_literal.get(), curr_val_literal.get()})
 | 
			
		||||
                            *function, {&result_val_literal, &curr_val_literal})
 | 
			
		||||
                        .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
                // Clear visit states so that the we can use the evaluate again
 | 
			
		||||
                // on the same computation.
 | 
			
		||||
                embedded_evaluator.ResetVisitStates();
 | 
			
		||||
 | 
			
		||||
                result_val = computed_result->Get<ReturnT>({});
 | 
			
		||||
                result_val = computed_result.Get<ReturnT>({});
 | 
			
		||||
              });
 | 
			
		||||
 | 
			
		||||
          return result_val;
 | 
			
		||||
@ -1961,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
  // literal (if there is one) to `reshaped_indices`.
 | 
			
		||||
  StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
 | 
			
		||||
      int64 index_vector_dim, const Literal& indices,
 | 
			
		||||
      std::unique_ptr<Literal>* reshaped_indices) {
 | 
			
		||||
      Literal* reshaped_indices) {
 | 
			
		||||
    if (indices.shape().dimensions_size() != index_vector_dim) {
 | 
			
		||||
      return std::cref(indices);
 | 
			
		||||
    }
 | 
			
		||||
@ -1970,7 +1961,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
                                 indices.shape().dimensions().end());
 | 
			
		||||
    new_shape.push_back(1);
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
 | 
			
		||||
    return std::cref(**reshaped_indices);
 | 
			
		||||
    return std::cref(*reshaped_indices);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Returns an ShapeUtil::IndexIterationSpace that iterates over the update
 | 
			
		||||
@ -2230,7 +2221,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
        scatter->scatter_dimension_numbers();
 | 
			
		||||
    const Literal& operand =
 | 
			
		||||
        parent_->GetEvaluatedLiteralFor(scatter->operand(0));
 | 
			
		||||
    std::unique_ptr<Literal> reshaped_scatter_indices;
 | 
			
		||||
    Literal reshaped_scatter_indices;
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(const Literal& scatter_indices,
 | 
			
		||||
                        ReshapedScatterIndices(dim_numbers.index_vector_dim(),
 | 
			
		||||
                                               parent_->GetEvaluatedLiteralFor(
 | 
			
		||||
@ -2260,7 +2251,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
 | 
			
		||||
    // Initialize the result with the operand. This makes it easier to handle
 | 
			
		||||
    // the updates even when the indices are repeated.
 | 
			
		||||
    std::unique_ptr<Literal> result = operand.CloneToUnique();
 | 
			
		||||
    Literal result = operand.Clone();
 | 
			
		||||
    HloEvaluator embedded_evaluator;
 | 
			
		||||
    auto scatter_inner_loop_body =
 | 
			
		||||
        [&](absl::Span<const int64> update_window_index,
 | 
			
		||||
@ -2299,19 +2290,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      auto result_value_literal =
 | 
			
		||||
          LiteralUtil::CreateR0<ReturnT>(result->Get<ReturnT>(input_index));
 | 
			
		||||
          LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index));
 | 
			
		||||
      auto update_value_literal =
 | 
			
		||||
          LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index));
 | 
			
		||||
      std::unique_ptr<Literal> updated_result =
 | 
			
		||||
      Literal updated_result =
 | 
			
		||||
          embedded_evaluator
 | 
			
		||||
              .Evaluate<const Literal*>(
 | 
			
		||||
                  *scatter->to_apply(),
 | 
			
		||||
                  {result_value_literal.get(), update_value_literal.get()})
 | 
			
		||||
                  {&result_value_literal, &update_value_literal})
 | 
			
		||||
              .ConsumeValueOrDie();
 | 
			
		||||
      // Clear visit states so that the we can use the evaluate again on the
 | 
			
		||||
      // same computation.
 | 
			
		||||
      embedded_evaluator.ResetVisitStates();
 | 
			
		||||
      result->Set<ReturnT>(input_index, updated_result->Get<ReturnT>({}));
 | 
			
		||||
      result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({}));
 | 
			
		||||
      return true;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
@ -2361,7 +2352,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
 | 
			
		||||
    auto result = LiteralUtil::CreateFromDimensions(
 | 
			
		||||
        shape.element_type(), AsInt64Slice(shape.dimensions()));
 | 
			
		||||
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
 | 
			
		||||
    TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
 | 
			
		||||
    parent_->evaluated_[slice] = std::move(result);
 | 
			
		||||
    return Status::OK();
 | 
			
		||||
  }
 | 
			
		||||
@ -2575,7 +2566,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    if (ShapeUtil::Rank(iota->shape()) > 1) {
 | 
			
		||||
      TF_ASSIGN_OR_RETURN(
 | 
			
		||||
          parent_->evaluated_[iota],
 | 
			
		||||
          result->Broadcast(iota->shape(), {iota->iota_dimension()}));
 | 
			
		||||
          result.Broadcast(iota->shape(), {iota->iota_dimension()}));
 | 
			
		||||
    } else {
 | 
			
		||||
      TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
 | 
			
		||||
      parent_->evaluated_[iota] = std::move(result);
 | 
			
		||||
@ -2645,9 +2636,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  template <typename IndexT>
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> DynamicSlice(
 | 
			
		||||
      const Literal& operand_literal, const Literal& start_indices_literal,
 | 
			
		||||
      const Shape& result_shape) {
 | 
			
		||||
  StatusOr<Literal> DynamicSlice(const Literal& operand_literal,
 | 
			
		||||
                                 const Literal& start_indices_literal,
 | 
			
		||||
                                 const Shape& result_shape) {
 | 
			
		||||
    auto start_indices_typed = start_indices_literal.data<IndexT>();
 | 
			
		||||
    std::vector<int64> start(start_indices_typed.begin(),
 | 
			
		||||
                             start_indices_typed.end());
 | 
			
		||||
@ -2660,9 +2651,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    std::vector<int64> operand_indices(start.size());
 | 
			
		||||
    auto result = absl::make_unique<Literal>(result_shape);
 | 
			
		||||
    Literal result(result_shape);
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
          for (int64 i = 0; i < operand_indices.size(); ++i) {
 | 
			
		||||
            CHECK_GE(multi_index[i] + start[i], 0);
 | 
			
		||||
            operand_indices[i] = multi_index[i] + start[i];
 | 
			
		||||
@ -2676,12 +2667,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  template <typename IndexT>
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
 | 
			
		||||
      const Literal& operand_literal, const Literal& update_literal,
 | 
			
		||||
      const Literal& start_indices_literal) {
 | 
			
		||||
    auto result = operand_literal.CloneToUnique();
 | 
			
		||||
  StatusOr<Literal> DynamicUpdateSlice(const Literal& operand_literal,
 | 
			
		||||
                                       const Literal& update_literal,
 | 
			
		||||
                                       const Literal& start_indices_literal) {
 | 
			
		||||
    auto result = operand_literal.Clone();
 | 
			
		||||
    auto start_indices_typed = start_indices_literal.data<IndexT>();
 | 
			
		||||
    const auto rank = ShapeUtil::Rank(result->shape());
 | 
			
		||||
    const auto rank = ShapeUtil::Rank(result.shape());
 | 
			
		||||
    std::vector<int64> start(start_indices_typed.begin(),
 | 
			
		||||
                             start_indices_typed.end());
 | 
			
		||||
    // Clamp the update start indices so the slice is in-bounds w.r.t the
 | 
			
		||||
@ -2689,15 +2680,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    for (int64 i = 0; i < rank; ++i) {
 | 
			
		||||
      start[i] = std::min<int64>(
 | 
			
		||||
          std::max<int64>(0, start[i]),
 | 
			
		||||
          result->shape().dimensions(i) - update_literal.shape().dimensions(i));
 | 
			
		||||
          result.shape().dimensions(i) - update_literal.shape().dimensions(i));
 | 
			
		||||
    }
 | 
			
		||||
    std::vector<int64> result_index(rank, 0);
 | 
			
		||||
 | 
			
		||||
    auto func = [&](absl::Span<const int64> update_index) {
 | 
			
		||||
      std::transform(update_index.begin(), update_index.end(), start.begin(),
 | 
			
		||||
                     result_index.begin(), std::plus<int64>());
 | 
			
		||||
      result->Set<ReturnT>(result_index,
 | 
			
		||||
                           update_literal.Get<ReturnT>(update_index));
 | 
			
		||||
      result.Set<ReturnT>(result_index,
 | 
			
		||||
                          update_literal.Get<ReturnT>(update_index));
 | 
			
		||||
      return true;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
@ -2710,7 +2701,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    return std::move(result);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
 | 
			
		||||
  StatusOr<Literal> ElementWiseUnaryOp(
 | 
			
		||||
      HloInstruction* instruction,
 | 
			
		||||
      const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
 | 
			
		||||
    const Literal& operand_literal =
 | 
			
		||||
@ -2723,7 +2714,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    return std::move(result_literal);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
 | 
			
		||||
  StatusOr<Literal> ElementWiseBinaryOp(
 | 
			
		||||
      HloInstruction* instruction,
 | 
			
		||||
      const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
 | 
			
		||||
          binary_op) {
 | 
			
		||||
@ -2745,10 +2736,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
 | 
			
		||||
    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
 | 
			
		||||
 | 
			
		||||
    auto result = absl::make_unique<Literal>(shape);
 | 
			
		||||
    Literal result(shape);
 | 
			
		||||
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
          return ConvertBinaryFunction(binary_op)(
 | 
			
		||||
              lhs_literal.Get<ReturnT>(multi_index),
 | 
			
		||||
              rhs_literal.Get<ReturnT>(multi_index));
 | 
			
		||||
@ -2757,7 +2748,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  template <typename LhsType, typename RhsType, typename EhsType>
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
 | 
			
		||||
  StatusOr<Literal> ElementwiseTernaryOp(
 | 
			
		||||
      HloInstruction* instruction,
 | 
			
		||||
      const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
 | 
			
		||||
    const auto shape = instruction->shape();
 | 
			
		||||
@ -2782,10 +2773,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 | 
			
		||||
    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
 | 
			
		||||
    const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
 | 
			
		||||
 | 
			
		||||
    auto result = absl::make_unique<Literal>(shape);
 | 
			
		||||
    Literal result(shape);
 | 
			
		||||
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
 | 
			
		||||
          return ternary_op(lhs_literal.Get<LhsType>(multi_index),
 | 
			
		||||
                            rhs_literal.Get<RhsType>(multi_index),
 | 
			
		||||
                            ehs_literal.Get<EhsType>(multi_index));
 | 
			
		||||
 | 
			
		||||
@ -250,7 +250,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
 | 
			
		||||
      TF_RET_CHECK(proto.has_literal());
 | 
			
		||||
      TF_ASSIGN_OR_RETURN(auto literal,
 | 
			
		||||
                          Literal::CreateFromProto(proto.literal()));
 | 
			
		||||
      instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
 | 
			
		||||
      instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    case HloOpcode::kFusion: {
 | 
			
		||||
@ -527,7 +527,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
 | 
			
		||||
    std::unique_ptr<Literal> literal) {
 | 
			
		||||
    Literal literal) {
 | 
			
		||||
  return absl::make_unique<HloConstantInstruction>(std::move(literal));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -359,8 +359,7 @@ class HloInstruction {
 | 
			
		||||
                                                         const string& name);
 | 
			
		||||
 | 
			
		||||
  // Creates a literal constant instruction.
 | 
			
		||||
  static std::unique_ptr<HloInstruction> CreateConstant(
 | 
			
		||||
      std::unique_ptr<Literal> literal);
 | 
			
		||||
  static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
 | 
			
		||||
 | 
			
		||||
  // Creates an Iota instruction.
 | 
			
		||||
  static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
 | 
			
		||||
 | 
			
		||||
@ -845,8 +845,8 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
 | 
			
		||||
      shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
 | 
			
		||||
    : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
 | 
			
		||||
HloConstantInstruction::HloConstantInstruction(Literal literal)
 | 
			
		||||
    : HloInstruction(HloOpcode::kConstant, literal.shape()),
 | 
			
		||||
      literal_(std::move(literal)) {}
 | 
			
		||||
 | 
			
		||||
HloConstantInstruction::HloConstantInstruction(const Shape& shape)
 | 
			
		||||
@ -854,7 +854,7 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape)
 | 
			
		||||
 | 
			
		||||
HloInstructionProto HloConstantInstruction::ToProto() const {
 | 
			
		||||
  HloInstructionProto proto = HloInstruction::ToProto();
 | 
			
		||||
  if (literal_ != nullptr) {
 | 
			
		||||
  if (literal_.has_value()) {
 | 
			
		||||
    *proto.mutable_literal() = literal_->ToProto();
 | 
			
		||||
  }
 | 
			
		||||
  return proto;
 | 
			
		||||
@ -876,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
 | 
			
		||||
 | 
			
		||||
  if (!mutable_array_subshape->has_layout() ||
 | 
			
		||||
      !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
 | 
			
		||||
    literal_ = literal_->Relayout(new_layout, shape_index);
 | 
			
		||||
    *literal_ = literal_->Relayout(new_layout, shape_index);
 | 
			
		||||
    *mutable_array_subshape->mutable_layout() = new_layout;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
@ -893,7 +893,8 @@ std::unique_ptr<HloInstruction>
 | 
			
		||||
HloConstantInstruction::CloneWithNewOperandsImpl(
 | 
			
		||||
    const Shape& shape, absl::Span<HloInstruction* const> new_operands,
 | 
			
		||||
    HloCloneContext* context) const {
 | 
			
		||||
  return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique());
 | 
			
		||||
  CHECK(literal_.has_value());
 | 
			
		||||
  return absl::make_unique<HloConstantInstruction>(literal_->Clone());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
 | 
			
		||||
@ -901,7 +902,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
 | 
			
		||||
    CanonicalNameMap* canonical_name_map) const {
 | 
			
		||||
  string operands;
 | 
			
		||||
  // For constants, show the actual value in place of an empty operand list.
 | 
			
		||||
  if (literal_ != nullptr &&
 | 
			
		||||
  if (literal_.has_value() &&
 | 
			
		||||
      ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
 | 
			
		||||
       options.print_large_constants())) {
 | 
			
		||||
    // Literal::ToString emits multidimensional arrays over multiple
 | 
			
		||||
@ -936,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag,
 | 
			
		||||
 | 
			
		||||
HloInstructionProto HloTraceInstruction::ToProto() const {
 | 
			
		||||
  HloInstructionProto proto = HloInstruction::ToProto();
 | 
			
		||||
  *proto.mutable_literal() = literal_->ToProto();
 | 
			
		||||
  *proto.mutable_literal() = literal_.ToProto();
 | 
			
		||||
  return proto;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -580,13 +580,13 @@ class HloSliceInstruction : public HloInstruction {
 | 
			
		||||
 | 
			
		||||
class HloConstantInstruction : public HloInstruction {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
 | 
			
		||||
  explicit HloConstantInstruction(Literal literal);
 | 
			
		||||
  // Used when the literal is too large and dropped.
 | 
			
		||||
  explicit HloConstantInstruction(const Shape& shape);
 | 
			
		||||
  // Returns the literal associated with this instruction.
 | 
			
		||||
  const Literal& literal() const { return *literal_; }
 | 
			
		||||
  // Returns whether there is literal associated with this instruction.
 | 
			
		||||
  bool HasLiteral() const { return literal_ != nullptr; }
 | 
			
		||||
  bool HasLiteral() const { return literal_.has_value(); }
 | 
			
		||||
  // Returns a serialized representation of this instruction.
 | 
			
		||||
  HloInstructionProto ToProto() const override;
 | 
			
		||||
 | 
			
		||||
@ -610,15 +610,14 @@ class HloConstantInstruction : public HloInstruction {
 | 
			
		||||
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
 | 
			
		||||
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
 | 
			
		||||
      HloCloneContext* context) const override;
 | 
			
		||||
  // TODO(b/36360764): Remove unique_ptr wrapping.
 | 
			
		||||
  std::unique_ptr<Literal> literal_;
 | 
			
		||||
  absl::optional<Literal> literal_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class HloTraceInstruction : public HloInstruction {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
 | 
			
		||||
  // Returns a tag to be used in tracing.
 | 
			
		||||
  string TracingTag() const { return literal_->GetR1U8AsString(); }
 | 
			
		||||
  string TracingTag() const { return literal_.GetR1U8AsString(); }
 | 
			
		||||
  // Returns a serialized representation of this instruction.
 | 
			
		||||
  HloInstructionProto ToProto() const override;
 | 
			
		||||
 | 
			
		||||
@ -631,8 +630,7 @@ class HloTraceInstruction : public HloInstruction {
 | 
			
		||||
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
 | 
			
		||||
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
 | 
			
		||||
      HloCloneContext* context) const override;
 | 
			
		||||
  // TODO(b/36360764): Remove unique_ptr wrapping.
 | 
			
		||||
  std::unique_ptr<Literal> literal_;
 | 
			
		||||
  Literal literal_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class HloFusionInstruction : public HloInstruction {
 | 
			
		||||
 | 
			
		||||
@ -105,16 +105,13 @@ class HloParser {
 | 
			
		||||
                            string* root_name);
 | 
			
		||||
  bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
 | 
			
		||||
  bool ParseControlPredecessors(HloInstruction* instruction);
 | 
			
		||||
  bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
 | 
			
		||||
  bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
 | 
			
		||||
  bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
                            const Shape& shape);
 | 
			
		||||
  bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
 | 
			
		||||
  bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
                          const Shape& shape);
 | 
			
		||||
  bool ParseLiteral(Literal* literal, const Shape& shape);
 | 
			
		||||
  bool ParseTupleLiteral(Literal* literal, const Shape& shape);
 | 
			
		||||
  bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
 | 
			
		||||
  bool ParseDenseLiteral(Literal* literal, const Shape& shape);
 | 
			
		||||
  bool ParseSparseLiteral(Literal* literal, const Shape& shape);
 | 
			
		||||
  template <typename LiteralNativeT>
 | 
			
		||||
  bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
 | 
			
		||||
                                const Shape& shape);
 | 
			
		||||
  bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape);
 | 
			
		||||
 | 
			
		||||
  // Sets the sub-value of literal at the given index to the given value. The
 | 
			
		||||
  // literal's shape must have the default layout.
 | 
			
		||||
@ -577,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    case HloOpcode::kConstant: {
 | 
			
		||||
      std::unique_ptr<Literal> literal;
 | 
			
		||||
      Literal literal;
 | 
			
		||||
      if (!ParseToken(TokKind::kLparen,
 | 
			
		||||
                      "expects '(' before constant literal") ||
 | 
			
		||||
          !ParseLiteral(&literal, shape) ||
 | 
			
		||||
@ -1810,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) {
 | 
			
		||||
// literal
 | 
			
		||||
//  ::= tuple
 | 
			
		||||
//  ::= non_tuple
 | 
			
		||||
bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
                             const Shape& shape) {
 | 
			
		||||
bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
 | 
			
		||||
  return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape)
 | 
			
		||||
                                   : ParseNonTupleLiteral(literal, shape);
 | 
			
		||||
}
 | 
			
		||||
@ -1821,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
// literal_list
 | 
			
		||||
//  ::= /*empty*/
 | 
			
		||||
//  ::= literal (',' literal)*
 | 
			
		||||
bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
                                  const Shape& shape) {
 | 
			
		||||
bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
 | 
			
		||||
  if (!EatShapeAndCheckCompatible(shape)) {
 | 
			
		||||
    return TokenError(StrCat("expects tuple constant in shape ",
 | 
			
		||||
                             ShapeUtil::HumanString(shape)));
 | 
			
		||||
@ -1830,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
  if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  std::vector<std::unique_ptr<Literal>> elements(
 | 
			
		||||
      ShapeUtil::TupleElementCount(shape));
 | 
			
		||||
  std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
 | 
			
		||||
 | 
			
		||||
  if (lexer_.GetKind() == TokKind::kRparen) {
 | 
			
		||||
    // empty
 | 
			
		||||
@ -1857,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
//   ::= rank01
 | 
			
		||||
//   ::= rank2345
 | 
			
		||||
// rank2345 ::= shape sparse_or_nested_array
 | 
			
		||||
bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
                                     const Shape& shape) {
 | 
			
		||||
bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
 | 
			
		||||
  if (LayoutUtil::IsSparseArray(shape)) {
 | 
			
		||||
    return ParseSparseLiteral(literal, shape);
 | 
			
		||||
  }
 | 
			
		||||
@ -1867,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
  return ParseDenseLiteral(literal, shape);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
                                  const Shape& shape) {
 | 
			
		||||
bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
 | 
			
		||||
  const tensorflow::int64 rank = ShapeUtil::Rank(shape);
 | 
			
		||||
  if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
 | 
			
		||||
    return false;
 | 
			
		||||
@ -1962,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
          // TODO(congliu): bool type literals with rank >= 1 are actually
 | 
			
		||||
          // printed in a compact form instead of "true" or "false". Fix that.
 | 
			
		||||
          if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
 | 
			
		||||
                                 linear_index++, literal->get())) {
 | 
			
		||||
                                 linear_index++, literal)) {
 | 
			
		||||
            return false;
 | 
			
		||||
          }
 | 
			
		||||
          lexer_.Lex();
 | 
			
		||||
@ -1973,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
            return Error(loc, StrCat("expects integer for primitive type: ",
 | 
			
		||||
                                     PrimitiveType_Name(shape.element_type())));
 | 
			
		||||
          }
 | 
			
		||||
          if (!SetValueInLiteral(value, linear_index++, literal->get())) {
 | 
			
		||||
          if (!SetValueInLiteral(value, linear_index++, literal)) {
 | 
			
		||||
            return false;
 | 
			
		||||
          }
 | 
			
		||||
        } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
 | 
			
		||||
@ -1984,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
                loc, StrCat("expect floating point value for primitive type: ",
 | 
			
		||||
                            PrimitiveType_Name(shape.element_type())));
 | 
			
		||||
          }
 | 
			
		||||
          if (!SetValueInLiteral(value, linear_index++, literal->get())) {
 | 
			
		||||
          if (!SetValueInLiteral(value, linear_index++, literal)) {
 | 
			
		||||
            return false;
 | 
			
		||||
          }
 | 
			
		||||
        } else {
 | 
			
		||||
@ -1996,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
    }  // end of switch
 | 
			
		||||
  } while (nest_level > 0);
 | 
			
		||||
 | 
			
		||||
  *literal = (*literal)->Relayout(shape.layout());
 | 
			
		||||
  *literal = literal->Relayout(shape.layout());
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
                                   const Shape& shape) {
 | 
			
		||||
bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
 | 
			
		||||
  if (!EatShapeAndCheckCompatible(shape)) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
@ -2041,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename LiteralNativeT>
 | 
			
		||||
bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
 | 
			
		||||
                                         const Shape& shape) {
 | 
			
		||||
bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) {
 | 
			
		||||
  std::vector<tensorflow::int64> index;
 | 
			
		||||
 | 
			
		||||
  tensorflow::int64 rank = ShapeUtil::Rank(shape);
 | 
			
		||||
 | 
			
		||||
  *literal = absl::make_unique<Literal>(shape);
 | 
			
		||||
  *literal = Literal(shape);
 | 
			
		||||
 | 
			
		||||
  if (!ParseToken(TokKind::kLbrace,
 | 
			
		||||
                  "expects '{' at the beginning of a sparse literal")) {
 | 
			
		||||
@ -2121,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
 | 
			
		||||
      return false;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if ((*literal)->sparse_element_count() + 1 ==
 | 
			
		||||
    if (literal->sparse_element_count() + 1 ==
 | 
			
		||||
        LayoutUtil::MaxSparseElements(shape.layout())) {
 | 
			
		||||
      return Error(
 | 
			
		||||
          lexer_.GetLoc(),
 | 
			
		||||
@ -2129,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
 | 
			
		||||
                 ShapeUtil::HumanStringWithLayout(shape)));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    (*literal)->AppendSparseElement(index, value);
 | 
			
		||||
    literal->AppendSparseElement(index, value);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  (*literal)->SortSparseElements();
 | 
			
		||||
  literal->SortSparseElements();
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -118,16 +118,16 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
 | 
			
		||||
    const absl::Span<const std::unique_ptr<Literal>> literals) {
 | 
			
		||||
    const absl::Span<const Literal> literals) {
 | 
			
		||||
  std::vector<const Literal*> literal_pointers;
 | 
			
		||||
  literal_pointers.reserve(literals.size());
 | 
			
		||||
  for (const auto& literal : literals) {
 | 
			
		||||
    literal_pointers.push_back(literal.get());
 | 
			
		||||
    literal_pointers.push_back(&literal);
 | 
			
		||||
  }
 | 
			
		||||
  return TransferLiteralsToDevice(literal_pointers);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
 | 
			
		||||
StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
 | 
			
		||||
    const ShapedBuffer& buffer) {
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(
 | 
			
		||||
      auto stream, backend().BorrowStream(backend().default_stream_executor()));
 | 
			
		||||
@ -135,7 +135,7 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
 | 
			
		||||
                                                                 buffer);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
 | 
			
		||||
StatusOr<Literal> HloRunner::Execute(
 | 
			
		||||
    std::unique_ptr<HloModule> module,
 | 
			
		||||
    const absl::Span<const Literal* const> arguments, bool run_hlo_passes,
 | 
			
		||||
    ExecutionProfile* profile) {
 | 
			
		||||
@ -150,15 +150,15 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
 | 
			
		||||
  return TransferLiteralFromDevice(result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
 | 
			
		||||
    std::unique_ptr<HloModule> module,
 | 
			
		||||
    const absl::Span<const std::unique_ptr<Literal>> arguments,
 | 
			
		||||
    bool run_hlo_passes, ExecutionProfile* profile) {
 | 
			
		||||
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
 | 
			
		||||
                                     const absl::Span<const Literal> arguments,
 | 
			
		||||
                                     bool run_hlo_passes,
 | 
			
		||||
                                     ExecutionProfile* profile) {
 | 
			
		||||
  // Construct a vector of plain pointers for the arguments.
 | 
			
		||||
  std::vector<const Literal*> argument_pointers;
 | 
			
		||||
  argument_pointers.reserve(arguments.size());
 | 
			
		||||
  for (const auto& argument : arguments) {
 | 
			
		||||
    argument_pointers.push_back(argument.get());
 | 
			
		||||
    argument_pointers.push_back(&argument);
 | 
			
		||||
  }
 | 
			
		||||
  return Execute(
 | 
			
		||||
      /*module=*/std::move(module),
 | 
			
		||||
@ -204,7 +204,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
 | 
			
		||||
      /*profile=*/profile);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
 | 
			
		||||
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
 | 
			
		||||
    std::unique_ptr<HloModule> module,
 | 
			
		||||
    const ReplicatedExecuteOptions& options) {
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(
 | 
			
		||||
@ -290,9 +290,9 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
 | 
			
		||||
        VLOG(1) << "Starting outfeed on device " << device;
 | 
			
		||||
        for (int64 step = 1;
 | 
			
		||||
             options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
 | 
			
		||||
          auto literal = absl::make_unique<Literal>();
 | 
			
		||||
          Literal literal;
 | 
			
		||||
          TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
 | 
			
		||||
              executor, options.outfeed_shape, literal.get()));
 | 
			
		||||
              executor, options.outfeed_shape, &literal));
 | 
			
		||||
          if (options.outfeed_values != nullptr) {
 | 
			
		||||
            options.outfeed_values->push_back(std::move(literal));
 | 
			
		||||
          }
 | 
			
		||||
@ -310,10 +310,10 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
 | 
			
		||||
                                                   argument_buffer_slices));
 | 
			
		||||
  LOG(INFO) << "Replicated execution terminated";
 | 
			
		||||
 | 
			
		||||
  std::vector<std::unique_ptr<Literal>> exec_results;
 | 
			
		||||
  std::vector<Literal> exec_results;
 | 
			
		||||
  for (int64 i = 0; i < options.num_replicas; ++i) {
 | 
			
		||||
    TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(Literal literal,
 | 
			
		||||
                        backend().transfer_manager()->TransferLiteralFromDevice(
 | 
			
		||||
                            streams[i].get(), results[i]));
 | 
			
		||||
    exec_results.push_back(std::move(literal));
 | 
			
		||||
 | 
			
		||||
@ -72,7 +72,7 @@ class HloRunner {
 | 
			
		||||
 | 
			
		||||
    // A pointer to a vector where the outfeed values will be stored. If
 | 
			
		||||
    // nullptr, the values will be read and discarded.
 | 
			
		||||
    std::vector<std::unique_ptr<Literal>>* outfeed_values = nullptr;
 | 
			
		||||
    std::vector<Literal>* outfeed_values = nullptr;
 | 
			
		||||
 | 
			
		||||
    // Whether the HLO passes should be run on the input module. Usually
 | 
			
		||||
    // saved modules are coming from after the HLO pass pipeline, so triggering
 | 
			
		||||
@ -106,24 +106,23 @@ class HloRunner {
 | 
			
		||||
  StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
 | 
			
		||||
      const absl::Span<const Literal* const> literals);
 | 
			
		||||
  StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
 | 
			
		||||
      const absl::Span<const std::unique_ptr<Literal>> literals);
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
 | 
			
		||||
      const ShapedBuffer& buffer);
 | 
			
		||||
      const absl::Span<const Literal> literals);
 | 
			
		||||
  StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
 | 
			
		||||
 | 
			
		||||
  // Executes the given module with given literals as input and returns the
 | 
			
		||||
  // result as a Literal.
 | 
			
		||||
  //
 | 
			
		||||
  // If run_hlo_passes is false, the module will be executed without Hlo
 | 
			
		||||
  // optimization.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> Execute(
 | 
			
		||||
      std::unique_ptr<HloModule> module,
 | 
			
		||||
      const absl::Span<const Literal* const> arguments,
 | 
			
		||||
      bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
 | 
			
		||||
  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
 | 
			
		||||
                            const absl::Span<const Literal* const> arguments,
 | 
			
		||||
                            bool run_hlo_passes = true,
 | 
			
		||||
                            ExecutionProfile* profile = nullptr);
 | 
			
		||||
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> Execute(
 | 
			
		||||
      std::unique_ptr<HloModule> module,
 | 
			
		||||
      const absl::Span<const std::unique_ptr<Literal>> arguments,
 | 
			
		||||
      bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
 | 
			
		||||
  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
 | 
			
		||||
                            const absl::Span<const Literal> arguments,
 | 
			
		||||
                            bool run_hlo_passes = true,
 | 
			
		||||
                            ExecutionProfile* profile = nullptr);
 | 
			
		||||
 | 
			
		||||
  // As Execute(), but accepts and returns device buffers instead of host
 | 
			
		||||
  // buffers.
 | 
			
		||||
@ -140,7 +139,7 @@ class HloRunner {
 | 
			
		||||
  // Executes a given HLO module into a set of replicas, and returns a map
 | 
			
		||||
  // with the replica number as key, and the corresponding returned literal as
 | 
			
		||||
  // value.
 | 
			
		||||
  StatusOr<std::vector<std::unique_ptr<Literal>>> ExecuteReplicated(
 | 
			
		||||
  StatusOr<std::vector<Literal>> ExecuteReplicated(
 | 
			
		||||
      std::unique_ptr<HloModule> module,
 | 
			
		||||
      const ReplicatedExecuteOptions& options);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -290,8 +290,8 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
 | 
			
		||||
  padding_config.add_dimensions()->set_interior_padding(-1);
 | 
			
		||||
  builder.AddInstruction(HloInstruction::CreatePad(
 | 
			
		||||
      ShapeUtil::MakeShape(F32, {100}), param,
 | 
			
		||||
      builder.AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
          LiteralUtil::Zero(F32).CloneToUnique())),
 | 
			
		||||
      builder.AddInstruction(
 | 
			
		||||
          HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
 | 
			
		||||
      padding_config));
 | 
			
		||||
 | 
			
		||||
  auto module = CreateNewModule();
 | 
			
		||||
@ -314,8 +314,8 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
 | 
			
		||||
  padding_config.add_dimensions()->set_interior_padding(-1);
 | 
			
		||||
  builder.AddInstruction(HloInstruction::CreatePad(
 | 
			
		||||
      ShapeUtil::MakeShape(F32, {100}), param,
 | 
			
		||||
      builder.AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
          LiteralUtil::Zero(F32).CloneToUnique())),
 | 
			
		||||
      builder.AddInstruction(
 | 
			
		||||
          HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())),
 | 
			
		||||
      padding_config));
 | 
			
		||||
 | 
			
		||||
  auto module = CreateNewModule();
 | 
			
		||||
 | 
			
		||||
@ -918,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
 | 
			
		||||
  // inner_broadcast_result is the Broadcast'(Const0) bit in
 | 
			
		||||
  // BinaryOp(Broadcast'(Const0), Const1)
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(
 | 
			
		||||
      std::unique_ptr<Literal> inner_broadcast_result,
 | 
			
		||||
      Literal inner_broadcast_result,
 | 
			
		||||
      broadcast_const_operand->literal().Broadcast(
 | 
			
		||||
          scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));
 | 
			
		||||
 | 
			
		||||
@ -928,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(
 | 
			
		||||
        literal_for_new_source,
 | 
			
		||||
        TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
 | 
			
		||||
            opcode, scalar_indexed_const->literal(), *inner_broadcast_result)));
 | 
			
		||||
            opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
 | 
			
		||||
  } else {
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(
 | 
			
		||||
        literal_for_new_source,
 | 
			
		||||
        TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
 | 
			
		||||
            opcode, *inner_broadcast_result, scalar_indexed_const->literal())));
 | 
			
		||||
            opcode, inner_broadcast_result, scalar_indexed_const->literal())));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
 | 
			
		||||
 | 
			
		||||
@ -347,21 +347,19 @@ class IndexedArrayAnalysis {
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Literal* TakeOwnership(std::unique_ptr<Literal> literal) {
 | 
			
		||||
  Literal* TakeOwnership(Literal literal) {
 | 
			
		||||
    owned_literals_.push_back(std::move(literal));
 | 
			
		||||
    return owned_literals_.back().get();
 | 
			
		||||
    return &owned_literals_.back();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  StatusOr<Literal*> TakeOwnership(
 | 
			
		||||
      StatusOr<std::unique_ptr<Literal>> literal_or_error) {
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
 | 
			
		||||
                        std::move(literal_or_error));
 | 
			
		||||
  StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
 | 
			
		||||
    owned_literals_.push_back(std::move(literal));
 | 
			
		||||
    return owned_literals_.back().get();
 | 
			
		||||
    return &owned_literals_.back();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::vector<std::unique_ptr<Array>> owned_tensors_;
 | 
			
		||||
  std::vector<std::unique_ptr<Literal>> owned_literals_;
 | 
			
		||||
  std::vector<Literal> owned_literals_;
 | 
			
		||||
  tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) {
 | 
			
		||||
  // Verify execution on CPU.
 | 
			
		||||
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 | 
			
		||||
  auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Test that `constant` function is changed to `broadcast`.
 | 
			
		||||
@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) {
 | 
			
		||||
  // Verify execution on CPU.
 | 
			
		||||
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 | 
			
		||||
  auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
 | 
			
		||||
@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
 | 
			
		||||
  // Verify execution on CPU.
 | 
			
		||||
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 | 
			
		||||
  auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -73,30 +73,29 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
 | 
			
		||||
 | 
			
		||||
  // Transform the ShapedBuffer arguments into literals which the evaluator
 | 
			
		||||
  // consumes.
 | 
			
		||||
  std::vector<std::unique_ptr<Literal>> arg_literals;
 | 
			
		||||
  std::vector<Literal> arg_literals;
 | 
			
		||||
  for (int64 p = 0; p < computation->num_parameters(); ++p) {
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> arg_literal,
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(Literal arg_literal,
 | 
			
		||||
                        transfer_manager->TransferLiteralFromDevice(
 | 
			
		||||
                            run_options->stream(), *arguments[p]));
 | 
			
		||||
    arg_literals.push_back(std::move(arg_literal));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Execute the graph using the HloEvaluator.
 | 
			
		||||
  std::unique_ptr<Literal> result_literal;
 | 
			
		||||
  Literal result_literal;
 | 
			
		||||
  {
 | 
			
		||||
    tensorflow::mutex_lock lock(evaluator_lock_);
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(result_literal,
 | 
			
		||||
                        evaluator_->Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
                            *computation, arg_literals));
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate<Literal>(
 | 
			
		||||
                                            *computation, arg_literals));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Transform the result literal back into a ShapedBuffer.
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
 | 
			
		||||
                      transfer_manager->AllocateScopedShapedBuffer(
 | 
			
		||||
                          result_literal->shape(), run_options->allocator(),
 | 
			
		||||
                          result_literal.shape(), run_options->allocator(),
 | 
			
		||||
                          executor->device_ordinal()));
 | 
			
		||||
  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
 | 
			
		||||
      run_options->stream(), *result_literal, result));
 | 
			
		||||
      run_options->stream(), result_literal, result));
 | 
			
		||||
 | 
			
		||||
  uint64 end_micros = tensorflow::Env::Default()->NowMicros();
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
 | 
			
		||||
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
 | 
			
		||||
    auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
 | 
			
		||||
        {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
 | 
			
		||||
    Shape ashape = constant_literal1->shape();
 | 
			
		||||
    Shape ashape = constant_literal1.shape();
 | 
			
		||||
 | 
			
		||||
    auto constant1 = builder.AddInstruction(
 | 
			
		||||
        HloInstruction::CreateConstant(std::move(constant_literal1)));
 | 
			
		||||
 | 
			
		||||
@ -68,9 +68,9 @@ Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
 | 
			
		||||
  module->clear_arguments();
 | 
			
		||||
  for (const ShapedBuffer* argument : arguments) {
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(
 | 
			
		||||
        std::unique_ptr<Literal> literal,
 | 
			
		||||
        Literal literal,
 | 
			
		||||
        transfer_manager->TransferLiteralFromDevice(stream, *argument));
 | 
			
		||||
    *module->add_arguments() = literal->ToProto();
 | 
			
		||||
    *module->add_arguments() = literal.ToProto();
 | 
			
		||||
  }
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
@ -80,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
 | 
			
		||||
                    TransferManager* transfer_manager, HloSnapshot* module) {
 | 
			
		||||
  module->clear_result();
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(
 | 
			
		||||
      std::unique_ptr<Literal> literal,
 | 
			
		||||
      Literal literal,
 | 
			
		||||
      transfer_manager->TransferLiteralFromDevice(stream, result));
 | 
			
		||||
  *module->mutable_result() = literal->ToProto();
 | 
			
		||||
  *module->mutable_result() = literal.ToProto();
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -928,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
 | 
			
		||||
                                       shaped_buffer->device_ordinal()));
 | 
			
		||||
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(
 | 
			
		||||
      std::unique_ptr<Literal> result_literal,
 | 
			
		||||
      Literal result_literal,
 | 
			
		||||
      execute_backend_->transfer_manager()->TransferLiteralFromDevice(
 | 
			
		||||
          stream.get(), *shaped_buffer));
 | 
			
		||||
 | 
			
		||||
  if (LayoutUtil::LayoutsInShapesEqual(*return_shape,
 | 
			
		||||
                                       result_literal->shape())) {
 | 
			
		||||
    *result->mutable_literal() = result_literal->ToProto();
 | 
			
		||||
  if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) {
 | 
			
		||||
    *result->mutable_literal() = result_literal.ToProto();
 | 
			
		||||
  } else {
 | 
			
		||||
    *result->mutable_literal() =
 | 
			
		||||
        result_literal->Relayout(*return_shape)->ToProto();
 | 
			
		||||
        result_literal.Relayout(*return_shape).ToProto();
 | 
			
		||||
  }
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
@ -959,9 +958,9 @@ std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
 | 
			
		||||
 | 
			
		||||
Status Service::TransferToServer(const TransferToServerRequest* arg,
 | 
			
		||||
                                 TransferToServerResponse* result) {
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(Literal literal,
 | 
			
		||||
                      Literal::CreateFromProto(arg->literal()));
 | 
			
		||||
  const Shape& shape = literal->shape();
 | 
			
		||||
  const Shape& shape = literal.shape();
 | 
			
		||||
 | 
			
		||||
  std::vector<se::StreamExecutor*> replicas;
 | 
			
		||||
  if (arg->has_device_handle()) {
 | 
			
		||||
@ -983,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg,
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        execute_backend_->transfer_manager()->TransferLiteralToDevice(
 | 
			
		||||
            stream.get(), *literal, shaped_buffer));
 | 
			
		||||
            stream.get(), literal, shaped_buffer));
 | 
			
		||||
    replicated_buffers.emplace_back(std::move(shaped_buffer));
 | 
			
		||||
  }
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(*result->mutable_data(),
 | 
			
		||||
@ -1018,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
 | 
			
		||||
    executor = replicas[arg->replica_id()];
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(Literal literal,
 | 
			
		||||
                      Literal::CreateFromProto(arg->literal()));
 | 
			
		||||
  return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
 | 
			
		||||
      executor, *literal);
 | 
			
		||||
  return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
 | 
			
		||||
                                                                       literal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
 | 
			
		||||
@ -1049,8 +1048,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
 | 
			
		||||
 | 
			
		||||
  TF_RETURN_IF_ERROR(
 | 
			
		||||
      execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
 | 
			
		||||
          executor, arg->shape_with_layout(), *literal));
 | 
			
		||||
  *result->mutable_literal() = literal->ToProto();
 | 
			
		||||
          executor, arg->shape_with_layout(), literal));
 | 
			
		||||
  *result->mutable_literal() = literal.ToProto();
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1085,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
 | 
			
		||||
                      HloModule::CreateFromProto(arg->computation(), config));
 | 
			
		||||
 | 
			
		||||
  HloEvaluator evaluator;
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(auto result_literal,
 | 
			
		||||
                      evaluator.Evaluate<std::unique_ptr<Literal>>(
 | 
			
		||||
                          *module, /*arg_literals=*/{}));
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<Literal>(
 | 
			
		||||
                                               *module, /*arg_literals=*/{}));
 | 
			
		||||
 | 
			
		||||
  // Since the result layout is non-effective to the Evaluator results, explicit
 | 
			
		||||
  // relayout here.
 | 
			
		||||
  //
 | 
			
		||||
  // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
 | 
			
		||||
  if (arg->has_output_layout()) {
 | 
			
		||||
    result_literal = result_literal->Relayout(arg->output_layout());
 | 
			
		||||
    result_literal = result_literal.Relayout(arg->output_layout());
 | 
			
		||||
  }
 | 
			
		||||
  *result->mutable_literal() = result_literal->ToProto();
 | 
			
		||||
  *result->mutable_literal() = result_literal.ToProto();
 | 
			
		||||
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -42,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() {
 | 
			
		||||
  return r;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
 | 
			
		||||
StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
 | 
			
		||||
    se::Stream* stream, const ShapedBuffer& device_buffer) {
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ret;
 | 
			
		||||
  StatusOr<Literal> ret;
 | 
			
		||||
 | 
			
		||||
  se::Stream* substream = stream->GetOrCreateSubStream();
 | 
			
		||||
  substream->ThenWaitFor(stream);
 | 
			
		||||
@ -63,7 +63,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
 | 
			
		||||
  if (!s.ok()) {
 | 
			
		||||
    return s;
 | 
			
		||||
  }
 | 
			
		||||
  return absl::make_unique<Literal>(std::move(literal));
 | 
			
		||||
  return std::move(literal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status TransferManager::TransferLiteralFromDevice(
 | 
			
		||||
@ -99,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice(
 | 
			
		||||
  return substream->BlockHostUntilDone();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
 | 
			
		||||
StatusOr<Literal> TransferManager::TransferArrayFromDevice(
 | 
			
		||||
    se::Stream* stream, const Shape& shape,
 | 
			
		||||
    const se::DeviceMemoryBase& source) {
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ret;
 | 
			
		||||
  StatusOr<Literal> ret;
 | 
			
		||||
  // Implement the synchronous version by waiting on the asynchronous version.
 | 
			
		||||
  // Use a substream so that if we are called from a HostCallback we don't
 | 
			
		||||
  // deadlock.
 | 
			
		||||
@ -122,7 +122,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
 | 
			
		||||
  if (!s.ok()) {
 | 
			
		||||
    return s;
 | 
			
		||||
  }
 | 
			
		||||
  return absl::make_unique<Literal>(std::move(literal));
 | 
			
		||||
  return std::move(literal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status TransferManager::TransferArrayToDevice(
 | 
			
		||||
 | 
			
		||||
@ -57,7 +57,7 @@ class TransferManager {
 | 
			
		||||
  // without waiting for any other operation on a stream to complete.
 | 
			
		||||
  //
 | 
			
		||||
  // This function should be avoided in favor of the asynchronous version below.
 | 
			
		||||
  virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
 | 
			
		||||
  virtual StatusOr<Literal> TransferLiteralFromDevice(
 | 
			
		||||
      se::Stream* stream, const ShapedBuffer& device_buffer);
 | 
			
		||||
  virtual Status TransferLiteralFromDevice(
 | 
			
		||||
      se::Stream* stream, const ShapedBuffer& device_buffer,
 | 
			
		||||
@ -113,9 +113,9 @@ class TransferManager {
 | 
			
		||||
  Status TransferArrayToDeviceAsync(se::Stream* stream,
 | 
			
		||||
                                    const LiteralSlice& literal,
 | 
			
		||||
                                    const se::DeviceMemoryBase& dest);
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice(
 | 
			
		||||
      se::Stream* stream, const Shape& shape,
 | 
			
		||||
      const se::DeviceMemoryBase& source);
 | 
			
		||||
  StatusOr<Literal> TransferArrayFromDevice(se::Stream* stream,
 | 
			
		||||
                                            const Shape& shape,
 | 
			
		||||
                                            const se::DeviceMemoryBase& source);
 | 
			
		||||
 | 
			
		||||
  // Transfers the given literal into the Infeed interface of the device,
 | 
			
		||||
  // using the given executor.
 | 
			
		||||
 | 
			
		||||
@ -555,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
 | 
			
		||||
  // Construct a tuple constant and kCopy it. Verify the points-to set of the
 | 
			
		||||
  // copy correctly correctly points into the nested elements of the constant.
 | 
			
		||||
  auto builder = HloComputation::Builder(TestName());
 | 
			
		||||
  auto tuple_constant = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
 | 
			
		||||
          {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
 | 
			
		||||
           LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
 | 
			
		||||
  Literal elements[] = {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
 | 
			
		||||
                        LiteralUtil::CreateR1<float>({2.0, 42})};
 | 
			
		||||
  auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
      LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
 | 
			
		||||
  auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
 | 
			
		||||
      tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -183,8 +183,7 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
 | 
			
		||||
  HloEvaluator evaluator(/*max_loop_iterations=*/0);
 | 
			
		||||
  auto* while_init = while_op->mutable_operand(0);
 | 
			
		||||
  auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> indvar_init_result =
 | 
			
		||||
      evaluator.Evaluate(indvar_init);
 | 
			
		||||
  StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
 | 
			
		||||
  if (!indvar_init_result.ok()) {
 | 
			
		||||
    VLOG(2) << "Couldn't evaluate induction variable init: "
 | 
			
		||||
            << indvar_init_result.status();
 | 
			
		||||
@ -197,31 +196,27 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
 | 
			
		||||
  auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
 | 
			
		||||
 | 
			
		||||
  // The initial value of the induction variable.
 | 
			
		||||
  std::unique_ptr<Literal> indvar_iter_val =
 | 
			
		||||
      std::move(indvar_init_result).ValueOrDie();
 | 
			
		||||
  Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
 | 
			
		||||
  for (int64 trip_count = 0; trip_count != max_value_returned + 1;
 | 
			
		||||
       ++trip_count) {
 | 
			
		||||
    auto* while_cond = while_op->while_condition();
 | 
			
		||||
    auto* while_cond_root = while_cond->root_instruction();
 | 
			
		||||
    auto* while_cond_indvar = NonConstantOperand(while_cond_root);
 | 
			
		||||
    StatusOr<std::unique_ptr<Literal>> result =
 | 
			
		||||
        evaluator.EvaluateWithSubstitutions(
 | 
			
		||||
            while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}});
 | 
			
		||||
    StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
 | 
			
		||||
        while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
 | 
			
		||||
    if (!result.ok()) {
 | 
			
		||||
      VLOG(2) << "Couldn't evaluate while cond: " << result.status();
 | 
			
		||||
      return nullopt;
 | 
			
		||||
    }
 | 
			
		||||
    if (result.ValueOrDie()->data<bool>() == absl::Span<const bool>{false}) {
 | 
			
		||||
    if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
 | 
			
		||||
      VLOG(2) << "Loop has static trip count of " << trip_count;
 | 
			
		||||
      return trip_count;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Calculate the value of the induction variable after one iteration of the
 | 
			
		||||
    // loop, and check whether the while condition is true with this new value.
 | 
			
		||||
    StatusOr<std::unique_ptr<Literal>> indvar_next_result =
 | 
			
		||||
        evaluator.EvaluateWithSubstitutions(
 | 
			
		||||
            while_body_indvar_update,
 | 
			
		||||
            {{while_body_indvar, indvar_iter_val.get()}});
 | 
			
		||||
    StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
 | 
			
		||||
        while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
 | 
			
		||||
    if (!indvar_next_result.ok()) {
 | 
			
		||||
      VLOG(2) << "Couldn't evaluate induction variable update: "
 | 
			
		||||
              << indvar_next_result.status();
 | 
			
		||||
 | 
			
		||||
@ -41,7 +41,6 @@ limitations under the License.
 | 
			
		||||
namespace xla {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ArrayElementwiseOpTest : public ClientLibraryTestBase {
 | 
			
		||||
 public:
 | 
			
		||||
  ErrorSpec error_spec_{0.0001, 0.0001};
 | 
			
		||||
@ -227,10 +226,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
 | 
			
		||||
                          0x8000000000000000LL,
 | 
			
		||||
                          0x8000000000000000LL,
 | 
			
		||||
                          1};
 | 
			
		||||
  std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
 | 
			
		||||
  auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
 | 
			
		||||
  Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
 | 
			
		||||
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
 | 
			
		||||
  std::unique_ptr<GlobalData> lhs_data =
 | 
			
		||||
      client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::vector<uint64> rhs{1,
 | 
			
		||||
                          0x7FFFFFFFFFFFFFFLL,
 | 
			
		||||
@ -241,10 +240,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
 | 
			
		||||
                          0,
 | 
			
		||||
                          1,
 | 
			
		||||
                          0x8000000000000000LL};
 | 
			
		||||
  std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
 | 
			
		||||
  auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
 | 
			
		||||
  Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
 | 
			
		||||
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
 | 
			
		||||
  std::unique_ptr<GlobalData> rhs_data =
 | 
			
		||||
      client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  Add(lhs_param, rhs_param);
 | 
			
		||||
 | 
			
		||||
@ -267,10 +266,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
 | 
			
		||||
                         1,
 | 
			
		||||
                         0,
 | 
			
		||||
                         -1};
 | 
			
		||||
  std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
 | 
			
		||||
  auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
 | 
			
		||||
  Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
 | 
			
		||||
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
 | 
			
		||||
  std::unique_ptr<GlobalData> lhs_data =
 | 
			
		||||
      client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::vector<int64> rhs{-1,
 | 
			
		||||
                         0,
 | 
			
		||||
@ -280,10 +279,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
 | 
			
		||||
                         0x7FFFFFFFFFFFFFFLL,
 | 
			
		||||
                         0x7FFFFFFFFFFFFFFFLL,
 | 
			
		||||
                         0x7FFFFFFFFFFFFFFFLL};
 | 
			
		||||
  std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
 | 
			
		||||
  auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
 | 
			
		||||
  Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
 | 
			
		||||
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
 | 
			
		||||
  std::unique_ptr<GlobalData> rhs_data =
 | 
			
		||||
      client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  Sub(lhs_param, rhs_param);
 | 
			
		||||
 | 
			
		||||
@ -299,16 +298,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
 | 
			
		||||
  std::vector<uint64> lhs{static_cast<uint64>(0x8000000000000000ULL)};
 | 
			
		||||
  std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
 | 
			
		||||
  auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
 | 
			
		||||
  Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
 | 
			
		||||
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
 | 
			
		||||
 | 
			
		||||
  std::vector<uint64> rhs{static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)};
 | 
			
		||||
  std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
 | 
			
		||||
  auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
 | 
			
		||||
  Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
 | 
			
		||||
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
 | 
			
		||||
 | 
			
		||||
  Lt(lhs_param, rhs_param);
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)});
 | 
			
		||||
  ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
 | 
			
		||||
@ -321,16 +320,16 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
 | 
			
		||||
    b_values.push_back(2 * i / static_cast<float>(count + 2));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
 | 
			
		||||
  Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
 | 
			
		||||
  std::unique_ptr<GlobalData> a_data =
 | 
			
		||||
      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(a_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto a_constant = ConstantR1<float>(&builder, a_values);
 | 
			
		||||
  auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param");
 | 
			
		||||
  auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
 | 
			
		||||
  Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
 | 
			
		||||
  std::unique_ptr<GlobalData> b_data =
 | 
			
		||||
      client_->TransferToServer(*b_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param");
 | 
			
		||||
      client_->TransferToServer(b_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param");
 | 
			
		||||
  auto b_param = ConstantR1<float>(&builder, b_values);
 | 
			
		||||
 | 
			
		||||
  auto sum1 = Add(a_constant, b_constant);
 | 
			
		||||
@ -1422,12 +1421,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
 | 
			
		||||
  std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
 | 
			
		||||
  std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values);
 | 
			
		||||
  Literal param_literal = LiteralUtil::CreateR1<float>(values);
 | 
			
		||||
  std::unique_ptr<GlobalData> param_data =
 | 
			
		||||
      client_->TransferToServer(*param_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto sum = ConstantR0<float>(&b, 0.0f);
 | 
			
		||||
  auto param = Parameter(&b, 0, param_literal->shape(), "param");
 | 
			
		||||
  auto param = Parameter(&b, 0, param_literal.shape(), "param");
 | 
			
		||||
  for (float exponent : exponents) {
 | 
			
		||||
    sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
 | 
			
		||||
  }
 | 
			
		||||
@ -1450,14 +1449,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
 | 
			
		||||
  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
 | 
			
		||||
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  std::unique_ptr<GlobalData> data0 =
 | 
			
		||||
      client_->TransferToServer(*literal0).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
      client_->TransferToServer(literal0).ConsumeValueOrDie();
 | 
			
		||||
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
  std::unique_ptr<GlobalData> data1 =
 | 
			
		||||
      client_->TransferToServer(*literal1).ConsumeValueOrDie();
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
 | 
			
		||||
      client_->TransferToServer(literal1).ConsumeValueOrDie();
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
 | 
			
		||||
  Pow(Exp(param0), param1);
 | 
			
		||||
 | 
			
		||||
  std::vector<float> expected(values0.size());
 | 
			
		||||
@ -1475,14 +1474,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
 | 
			
		||||
  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
 | 
			
		||||
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  std::unique_ptr<GlobalData> data0 =
 | 
			
		||||
      client_->TransferToServer(*literal0).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
      client_->TransferToServer(literal0).ConsumeValueOrDie();
 | 
			
		||||
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
  std::unique_ptr<GlobalData> data1 =
 | 
			
		||||
      client_->TransferToServer(*literal1).ConsumeValueOrDie();
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
 | 
			
		||||
      client_->TransferToServer(literal1).ConsumeValueOrDie();
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
 | 
			
		||||
  Log(Pow(param0, param1));
 | 
			
		||||
 | 
			
		||||
  std::vector<float> expected(values0.size());
 | 
			
		||||
@ -1500,14 +1499,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
 | 
			
		||||
  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
 | 
			
		||||
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  std::unique_ptr<GlobalData> data0 =
 | 
			
		||||
      client_->TransferToServer(*literal0).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
      client_->TransferToServer(literal0).ConsumeValueOrDie();
 | 
			
		||||
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
  std::unique_ptr<GlobalData> data1 =
 | 
			
		||||
      client_->TransferToServer(*literal1).ConsumeValueOrDie();
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
 | 
			
		||||
      client_->TransferToServer(literal1).ConsumeValueOrDie();
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
 | 
			
		||||
  Mul(Exp(param0), Exp(param1));
 | 
			
		||||
 | 
			
		||||
  std::vector<float> expected(values0.size());
 | 
			
		||||
@ -1525,14 +1524,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
 | 
			
		||||
  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
 | 
			
		||||
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  std::unique_ptr<GlobalData> data0 =
 | 
			
		||||
      client_->TransferToServer(*literal0).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
      client_->TransferToServer(literal0).ConsumeValueOrDie();
 | 
			
		||||
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
  std::unique_ptr<GlobalData> data1 =
 | 
			
		||||
      client_->TransferToServer(*literal1).ConsumeValueOrDie();
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
 | 
			
		||||
      client_->TransferToServer(literal1).ConsumeValueOrDie();
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
 | 
			
		||||
  Div(param0, Exp(param1));
 | 
			
		||||
 | 
			
		||||
  std::vector<float> expected(values0.size());
 | 
			
		||||
@ -1551,20 +1550,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
 | 
			
		||||
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
 | 
			
		||||
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  std::unique_ptr<GlobalData> data0 =
 | 
			
		||||
      client_->TransferToServer(*literal0).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(literal0).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
  std::unique_ptr<GlobalData> data1 =
 | 
			
		||||
      client_->TransferToServer(*literal1).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(literal1).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
 | 
			
		||||
  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
 | 
			
		||||
  std::unique_ptr<GlobalData> data2 =
 | 
			
		||||
      client_->TransferToServer(*literal2).ConsumeValueOrDie();
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
 | 
			
		||||
  auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
 | 
			
		||||
      client_->TransferToServer(literal2).ConsumeValueOrDie();
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
 | 
			
		||||
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
 | 
			
		||||
  Div(Div(param0, param1), param2);
 | 
			
		||||
 | 
			
		||||
  std::vector<float> expected(values0.size());
 | 
			
		||||
@ -1583,21 +1582,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
 | 
			
		||||
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
 | 
			
		||||
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  std::unique_ptr<GlobalData> data0 =
 | 
			
		||||
      client_->TransferToServer(*literal0).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(literal0).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
  std::unique_ptr<GlobalData> data1 =
 | 
			
		||||
      client_->TransferToServer(*literal1).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(literal1).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
 | 
			
		||||
  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
 | 
			
		||||
  std::unique_ptr<GlobalData> data2 =
 | 
			
		||||
      client_->TransferToServer(*literal2).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(literal2).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
 | 
			
		||||
  auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
 | 
			
		||||
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
 | 
			
		||||
  Div(param0, Div(param1, param2));
 | 
			
		||||
 | 
			
		||||
  std::vector<float> expected(values0.size());
 | 
			
		||||
@ -1616,21 +1615,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
 | 
			
		||||
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
 | 
			
		||||
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  std::unique_ptr<GlobalData> data0 =
 | 
			
		||||
      client_->TransferToServer(*literal0).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(literal0).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
  std::unique_ptr<GlobalData> data1 =
 | 
			
		||||
      client_->TransferToServer(*literal1).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(literal1).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
 | 
			
		||||
  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
 | 
			
		||||
  std::unique_ptr<GlobalData> data2 =
 | 
			
		||||
      client_->TransferToServer(*literal2).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(literal2).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
 | 
			
		||||
  auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
 | 
			
		||||
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
 | 
			
		||||
  Div(param0, Pow(param1, param2));
 | 
			
		||||
 | 
			
		||||
  std::vector<float> expected(values0.size());
 | 
			
		||||
@ -1650,26 +1649,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
 | 
			
		||||
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
 | 
			
		||||
  std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
 | 
			
		||||
  std::unique_ptr<GlobalData> data0 =
 | 
			
		||||
      client_->TransferToServer(*literal0).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(literal0).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
 | 
			
		||||
  std::unique_ptr<GlobalData> data1 =
 | 
			
		||||
      client_->TransferToServer(*literal1).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(literal1).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
 | 
			
		||||
  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
 | 
			
		||||
  std::unique_ptr<GlobalData> data2 =
 | 
			
		||||
      client_->TransferToServer(*literal2).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(literal2).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal3 = LiteralUtil::CreateR1<float>(values3);
 | 
			
		||||
  Literal literal3 = LiteralUtil::CreateR1<float>(values3);
 | 
			
		||||
  std::unique_ptr<GlobalData> data3 =
 | 
			
		||||
      client_->TransferToServer(*literal3).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(literal3).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
 | 
			
		||||
  auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
 | 
			
		||||
  auto param3 = Parameter(&b, 3, literal3->shape(), "param2");
 | 
			
		||||
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
 | 
			
		||||
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
 | 
			
		||||
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
 | 
			
		||||
  auto param3 = Parameter(&b, 3, literal3.shape(), "param2");
 | 
			
		||||
  Div(Div(param0, param1), Div(param2, param3));
 | 
			
		||||
 | 
			
		||||
  std::vector<float> expected(values0.size());
 | 
			
		||||
@ -2096,18 +2095,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
 | 
			
		||||
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> param0_literal =
 | 
			
		||||
  Literal param0_literal =
 | 
			
		||||
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
 | 
			
		||||
  std::unique_ptr<GlobalData> param0_data =
 | 
			
		||||
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> param1_literal =
 | 
			
		||||
  Literal param1_literal =
 | 
			
		||||
      LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
 | 
			
		||||
  std::unique_ptr<GlobalData> param1_data =
 | 
			
		||||
      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
 | 
			
		||||
  auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
 | 
			
		||||
  auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
 | 
			
		||||
  auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
 | 
			
		||||
  Add(p0, p1);
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
 | 
			
		||||
@ -2118,18 +2117,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
 | 
			
		||||
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> param0_literal =
 | 
			
		||||
  Literal param0_literal =
 | 
			
		||||
      LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
 | 
			
		||||
  std::unique_ptr<GlobalData> param0_data =
 | 
			
		||||
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> param1_literal =
 | 
			
		||||
  Literal param1_literal =
 | 
			
		||||
      LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
 | 
			
		||||
  std::unique_ptr<GlobalData> param1_data =
 | 
			
		||||
      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
 | 
			
		||||
  auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
 | 
			
		||||
  auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
 | 
			
		||||
  auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
 | 
			
		||||
  Add(p0, p1);
 | 
			
		||||
 | 
			
		||||
  Array3D<float> expected(0, 7, 0);
 | 
			
		||||
@ -2140,13 +2139,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
 | 
			
		||||
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> param0_literal =
 | 
			
		||||
  Literal param0_literal =
 | 
			
		||||
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
 | 
			
		||||
  std::unique_ptr<GlobalData> param0_data =
 | 
			
		||||
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
 | 
			
		||||
  auto p = Parameter(&builder, 0, param0_literal->shape(), "param0");
 | 
			
		||||
  auto p = Parameter(&builder, 0, param0_literal.shape(), "param0");
 | 
			
		||||
  Add(a, p);
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
 | 
			
		||||
@ -2206,9 +2205,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
 | 
			
		||||
       0.08,  -1.24, -0.92, 0.49,  1.17,  -0.45, -1.31, -1.44, -0.13, -1.31,
 | 
			
		||||
       -0.79, 1.41,  1.21,  1.05});
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(auto input_data,
 | 
			
		||||
                          client_->TransferToServer(*input_literal));
 | 
			
		||||
                          client_->TransferToServer(input_literal));
 | 
			
		||||
 | 
			
		||||
  auto input = Parameter(&builder, 0, input_literal->shape(), "input");
 | 
			
		||||
  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
 | 
			
		||||
  Tanh(input);
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR1<float>(
 | 
			
		||||
@ -2239,7 +2238,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
 | 
			
		||||
 | 
			
		||||
  // Just to help make sense of the scales here -- exp(89) saturates float32 and
 | 
			
		||||
  // exp(-10) is smaller than our error spec.
 | 
			
		||||
  std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
 | 
			
		||||
  Literal input_literal = LiteralUtil::CreateR1<float>(
 | 
			
		||||
      {1.02,   -0.32,  0.85,   0.9,    1.23,   -0.91,  -0.49, 0.8,    -1.31,
 | 
			
		||||
       -1.44,  -0.13,  -1.31,  -0.79,  1.41,   1.21,   1.05,  -195.6, -194.5,
 | 
			
		||||
       -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5,  -17.4,
 | 
			
		||||
@ -2252,16 +2251,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
 | 
			
		||||
       78.3,   79.4,   80.5,   81.6,   82.7,   83.8,   84.9,  85.2,   86.3,
 | 
			
		||||
       86.4,   86.5,   87.6,   87.7,   87.8,   87.9});
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
 | 
			
		||||
                          client_->TransferToServer(*input_literal));
 | 
			
		||||
                          client_->TransferToServer(input_literal));
 | 
			
		||||
 | 
			
		||||
  auto input = Parameter(&builder, 0, input_literal->shape(), "input");
 | 
			
		||||
  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
 | 
			
		||||
  Exp(input);
 | 
			
		||||
 | 
			
		||||
  std::vector<float> expected_result;
 | 
			
		||||
  int64 input_size = input_literal->shape().dimensions(0);
 | 
			
		||||
  int64 input_size = input_literal.shape().dimensions(0);
 | 
			
		||||
  expected_result.reserve(input_size);
 | 
			
		||||
  for (int64 i = 0; i < input_size; i++) {
 | 
			
		||||
    expected_result.push_back(std::exp(input_literal->Get<float>({i})));
 | 
			
		||||
    expected_result.push_back(std::exp(input_literal.Get<float>({i})));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
 | 
			
		||||
@ -2273,7 +2272,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
 | 
			
		||||
  // implementation on XLA CPU.
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
 | 
			
		||||
  Literal input_literal = LiteralUtil::CreateR1<float>(
 | 
			
		||||
      {-1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
 | 
			
		||||
       -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
 | 
			
		||||
       198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
 | 
			
		||||
@ -2290,16 +2289,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
 | 
			
		||||
       1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
 | 
			
		||||
       1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
 | 
			
		||||
                          client_->TransferToServer(*input_literal));
 | 
			
		||||
                          client_->TransferToServer(input_literal));
 | 
			
		||||
 | 
			
		||||
  auto input = Parameter(&builder, 0, input_literal->shape(), "input");
 | 
			
		||||
  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
 | 
			
		||||
  Log(input);
 | 
			
		||||
 | 
			
		||||
  std::vector<float> expected_result;
 | 
			
		||||
  int64 input_size = input_literal->shape().dimensions(0);
 | 
			
		||||
  int64 input_size = input_literal.shape().dimensions(0);
 | 
			
		||||
  expected_result.reserve(input_size);
 | 
			
		||||
  for (int64 i = 0; i < input_size; i++) {
 | 
			
		||||
    expected_result.push_back(std::log(input_literal->Get<float>({i})));
 | 
			
		||||
    expected_result.push_back(std::log(input_literal.Get<float>({i})));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
 | 
			
		||||
@ -2465,10 +2464,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
 | 
			
		||||
  auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0});
 | 
			
		||||
  Tuple(&builder, {cmp_dim_0, cmp_dim_1});
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::MakeTuple(
 | 
			
		||||
      {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(),
 | 
			
		||||
       LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()});
 | 
			
		||||
  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
 | 
			
		||||
  auto expected = LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}),
 | 
			
		||||
       LiteralUtil::CreateR2<bool>({{true, false}, {false, false}})});
 | 
			
		||||
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
 | 
			
		||||
@ -2821,10 +2820,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
 | 
			
		||||
  std::iota(r1.begin(), r1.end(), 1.0);
 | 
			
		||||
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  std::unique_ptr<Literal> a_literal =
 | 
			
		||||
      LiteralUtil::CreateR4FromArray4DWithLayout(
 | 
			
		||||
          r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
 | 
			
		||||
  auto a = ConstantLiteral(&builder, *a_literal);
 | 
			
		||||
  Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
 | 
			
		||||
      r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
 | 
			
		||||
  auto a = ConstantLiteral(&builder, a_literal);
 | 
			
		||||
  auto b = ConstantR1<float>(&builder, r1);
 | 
			
		||||
  Add(a, b, {1});
 | 
			
		||||
 | 
			
		||||
@ -2886,11 +2884,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
 | 
			
		||||
  auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
 | 
			
		||||
  auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto x = Parameter(&builder, 0, x_literal->shape(), "x");
 | 
			
		||||
  auto y = Parameter(&builder, 1, y_literal->shape(), "y");
 | 
			
		||||
  auto x = Parameter(&builder, 0, x_literal.shape(), "x");
 | 
			
		||||
  auto y = Parameter(&builder, 1, y_literal.shape(), "y");
 | 
			
		||||
  auto slice = Slice(x, {1}, {2}, {1});
 | 
			
		||||
  Sub(slice, y);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -63,7 +63,7 @@ class BatchNormalizationTest
 | 
			
		||||
        {5.0f, 4.4f},   // p2
 | 
			
		||||
    });
 | 
			
		||||
    input_array_.FillWithPZ(pz);
 | 
			
		||||
    input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
 | 
			
		||||
    input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
 | 
			
		||||
    CHECK_EQ(kSamples, input_array_.planes());
 | 
			
		||||
    CHECK_EQ(kZ, input_array_.depth());
 | 
			
		||||
    CHECK_EQ(kY, input_array_.height());
 | 
			
		||||
@ -242,14 +242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
 | 
			
		||||
  BatchNormTraining(operand, scale, offset,
 | 
			
		||||
                    /*epsilon=*/0.001, kFeatureIndex);
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::MakeTuple(
 | 
			
		||||
  auto expected = LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
 | 
			
		||||
                                     {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
 | 
			
		||||
           .get(),
 | 
			
		||||
       LiteralUtil::CreateR1<float>({4, 5}).get(),
 | 
			
		||||
       LiteralUtil::CreateR1<float>({5, 5}).get()});
 | 
			
		||||
                                     {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
 | 
			
		||||
       LiteralUtil::CreateR1<float>({4, 5}),
 | 
			
		||||
       LiteralUtil::CreateR1<float>({5, 5})});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
 | 
			
		||||
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
 | 
			
		||||
@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
 | 
			
		||||
  BatchNormTraining(operand, scale, offset,
 | 
			
		||||
                    /*epsilon=*/0.001, kFeatureIndex);
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::MakeTuple(
 | 
			
		||||
  auto expected = LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
 | 
			
		||||
                                     {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
 | 
			
		||||
           .get(),
 | 
			
		||||
       LiteralUtil::CreateR1<float>({4, 5}).get(),
 | 
			
		||||
       LiteralUtil::CreateR1<float>({5, 5}).get()});
 | 
			
		||||
                                     {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
 | 
			
		||||
       LiteralUtil::CreateR1<float>({4, 5}),
 | 
			
		||||
       LiteralUtil::CreateR1<float>({5, 5})});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
 | 
			
		||||
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
 | 
			
		||||
@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
 | 
			
		||||
  BatchNormTraining(h0, h1, h2,
 | 
			
		||||
                    /*epsilon=*/1, kFeatureIndex);
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::MakeTuple(
 | 
			
		||||
      {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
 | 
			
		||||
           .get(),
 | 
			
		||||
       LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
 | 
			
		||||
       LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
 | 
			
		||||
  auto expected = LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
 | 
			
		||||
       LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
 | 
			
		||||
       LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(&builder, *expected,
 | 
			
		||||
  ComputeAndCompareTuple(&builder, expected,
 | 
			
		||||
                         {operand.get(), scale.get(), offset.get()},
 | 
			
		||||
                         ErrorSpec(0.1));
 | 
			
		||||
}
 | 
			
		||||
@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
 | 
			
		||||
  BatchNormTraining(h0, h1, h2,
 | 
			
		||||
                    /*epsilon=*/-100, kFeatureIndex);
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::MakeTuple(
 | 
			
		||||
  auto expected = LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {LiteralUtil::CreateR3FromArray3D<float>(
 | 
			
		||||
           {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
 | 
			
		||||
           .get(),
 | 
			
		||||
       LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
 | 
			
		||||
       LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
 | 
			
		||||
           {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
 | 
			
		||||
       LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
 | 
			
		||||
       LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(&builder, *expected,
 | 
			
		||||
  ComputeAndCompareTuple(&builder, expected,
 | 
			
		||||
                         {operand.get(), scale.get(), offset.get()},
 | 
			
		||||
                         ErrorSpec(0.1));
 | 
			
		||||
}
 | 
			
		||||
@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
 | 
			
		||||
  BatchNormGrad(operand, scale, mean, var, grad_output,
 | 
			
		||||
                /*epsilon=*/0.0, kFeatureIndex);
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::MakeTuple(
 | 
			
		||||
  auto expected = LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
 | 
			
		||||
                                     {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
 | 
			
		||||
           .get(),
 | 
			
		||||
       LiteralUtil::CreateR1<float>({0, 0}).get(),
 | 
			
		||||
       LiteralUtil::CreateR1<float>({16, 20}).get()});
 | 
			
		||||
                                     {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
 | 
			
		||||
       LiteralUtil::CreateR1<float>({0, 0}),
 | 
			
		||||
       LiteralUtil::CreateR1<float>({16, 20})});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
 | 
			
		||||
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct BatchNormTestParam {
 | 
			
		||||
@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
 | 
			
		||||
  auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
 | 
			
		||||
 | 
			
		||||
  auto input_activations =
 | 
			
		||||
      Parameter(&builder, 0, input_literal->shape(), "input");
 | 
			
		||||
      Parameter(&builder, 0, input_literal.shape(), "input");
 | 
			
		||||
  auto scale_activations =
 | 
			
		||||
      Parameter(&builder, 1, scale_literal->shape(), "offset");
 | 
			
		||||
      Parameter(&builder, 1, scale_literal.shape(), "offset");
 | 
			
		||||
  auto offset_activations =
 | 
			
		||||
      Parameter(&builder, 2, offset_literal->shape(), "scale");
 | 
			
		||||
      Parameter(&builder, 2, offset_literal.shape(), "scale");
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::MakeTuple(
 | 
			
		||||
      {expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
 | 
			
		||||
       LiteralUtil::CreateR1<float>(var).get()});
 | 
			
		||||
  auto expected = LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {expected_normalized, LiteralUtil::CreateR1<float>(mean),
 | 
			
		||||
       LiteralUtil::CreateR1<float>(var)});
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<GlobalData> input_data =
 | 
			
		||||
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(input_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<GlobalData> scale_data =
 | 
			
		||||
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<GlobalData> offset_data =
 | 
			
		||||
      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(offset_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  BatchNormTraining(input_activations, scale_activations, offset_activations,
 | 
			
		||||
                    epsilon, feature_index);
 | 
			
		||||
@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
 | 
			
		||||
  // testcase.
 | 
			
		||||
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
 | 
			
		||||
  ComputeAndCompareTuple(
 | 
			
		||||
      &builder, *expected,
 | 
			
		||||
      &builder, expected,
 | 
			
		||||
      {input_data.get(), scale_data.get(), offset_data.get()},
 | 
			
		||||
      ErrorSpec(0.01, 1));
 | 
			
		||||
}
 | 
			
		||||
@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
 | 
			
		||||
  auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
 | 
			
		||||
 | 
			
		||||
  auto input_activations =
 | 
			
		||||
      Parameter(&builder, 0, input_literal->shape(), "input");
 | 
			
		||||
      Parameter(&builder, 0, input_literal.shape(), "input");
 | 
			
		||||
  auto scale_activations =
 | 
			
		||||
      Parameter(&builder, 1, scale_literal->shape(), "offset");
 | 
			
		||||
      Parameter(&builder, 1, scale_literal.shape(), "offset");
 | 
			
		||||
  auto offset_activations =
 | 
			
		||||
      Parameter(&builder, 2, offset_literal->shape(), "scale");
 | 
			
		||||
  auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean");
 | 
			
		||||
      Parameter(&builder, 2, offset_literal.shape(), "scale");
 | 
			
		||||
  auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
 | 
			
		||||
  auto variance_activations =
 | 
			
		||||
      Parameter(&builder, 4, var_literal->shape(), "variance");
 | 
			
		||||
      Parameter(&builder, 4, var_literal.shape(), "variance");
 | 
			
		||||
 | 
			
		||||
  Array4D<float> expected = normalized;
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<GlobalData> input_data =
 | 
			
		||||
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(input_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<GlobalData> scale_data =
 | 
			
		||||
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<GlobalData> offset_data =
 | 
			
		||||
      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(offset_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<GlobalData> mean_data =
 | 
			
		||||
      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(mean_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<GlobalData> variance_data =
 | 
			
		||||
      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(var_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  BatchNormInference(input_activations, scale_activations, offset_activations,
 | 
			
		||||
                     mean_activations, variance_activations, epsilon,
 | 
			
		||||
@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
 | 
			
		||||
  auto grad_output_literal =
 | 
			
		||||
      LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
 | 
			
		||||
 | 
			
		||||
  auto input_parameter =
 | 
			
		||||
      Parameter(&builder, 0, input_literal->shape(), "input");
 | 
			
		||||
  auto scale_parameter =
 | 
			
		||||
      Parameter(&builder, 1, scale_literal->shape(), "scale");
 | 
			
		||||
  auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean");
 | 
			
		||||
  auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance");
 | 
			
		||||
  auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
 | 
			
		||||
  auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
 | 
			
		||||
  auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
 | 
			
		||||
  auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
 | 
			
		||||
  auto grad_output_parameter =
 | 
			
		||||
      Parameter(&builder, 4, grad_output_literal->shape(), "grad_output");
 | 
			
		||||
      Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<GlobalData> input_data =
 | 
			
		||||
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(input_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<GlobalData> scale_data =
 | 
			
		||||
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<GlobalData> mean_data =
 | 
			
		||||
      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(mean_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<GlobalData> var_data =
 | 
			
		||||
      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(var_literal).ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<GlobalData> grad_output_data =
 | 
			
		||||
      client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(grad_output_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
 | 
			
		||||
                grad_output_parameter, epsilon, feature_index);
 | 
			
		||||
 | 
			
		||||
  auto expected =
 | 
			
		||||
      LiteralUtil::MakeTuple({expected_grad_activation.get(),
 | 
			
		||||
                              LiteralUtil::CreateR1<float>(grad_scale).get(),
 | 
			
		||||
                              LiteralUtil::CreateR1<float>(grad_offset).get()});
 | 
			
		||||
  auto expected = LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
 | 
			
		||||
       LiteralUtil::CreateR1<float>(grad_offset)});
 | 
			
		||||
 | 
			
		||||
  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
 | 
			
		||||
  // disables constant folding, but we want it enabled for our zero-sized tensor
 | 
			
		||||
  // testcase.
 | 
			
		||||
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(&builder, *expected,
 | 
			
		||||
  ComputeAndCompareTuple(&builder, expected,
 | 
			
		||||
                         {input_data.get(), scale_data.get(), mean_data.get(),
 | 
			
		||||
                          var_data.get(), grad_output_data.get()},
 | 
			
		||||
                         ErrorSpec(0.01, 1));
 | 
			
		||||
 | 
			
		||||
@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
 | 
			
		||||
 | 
			
		||||
  BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::MakeTuple(
 | 
			
		||||
  auto expected = LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {LiteralUtil::CreateR4<bfloat16>(
 | 
			
		||||
           {{{{static_cast<bfloat16>(-1.6875f)},
 | 
			
		||||
              {static_cast<bfloat16>(-2.04f)}},
 | 
			
		||||
             {{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
 | 
			
		||||
            {{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
 | 
			
		||||
             {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
 | 
			
		||||
           .get(),
 | 
			
		||||
             {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}),
 | 
			
		||||
       LiteralUtil::CreateR1<bfloat16>(
 | 
			
		||||
           {static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
 | 
			
		||||
           .get(),
 | 
			
		||||
           {static_cast<bfloat16>(4), static_cast<bfloat16>(5)}),
 | 
			
		||||
       LiteralUtil::CreateR1<bfloat16>(
 | 
			
		||||
           {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
 | 
			
		||||
           .get()});
 | 
			
		||||
           {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02));
 | 
			
		||||
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
 | 
			
		||||
@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
 | 
			
		||||
  BatchNormGrad(operand, scale, mean, var, grad_output,
 | 
			
		||||
                /*epsilon=*/0.0, kFeatureIndex);
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::MakeTuple(
 | 
			
		||||
  auto expected = LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
      {LiteralUtil::CreateR4<bfloat16>(
 | 
			
		||||
           {{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
 | 
			
		||||
             {{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
 | 
			
		||||
            {{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
 | 
			
		||||
             {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
 | 
			
		||||
           .get(),
 | 
			
		||||
             {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}),
 | 
			
		||||
       LiteralUtil::CreateR1<bfloat16>(
 | 
			
		||||
           {static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
 | 
			
		||||
           .get(),
 | 
			
		||||
           {static_cast<bfloat16>(0), static_cast<bfloat16>(0)}),
 | 
			
		||||
       LiteralUtil::CreateR1<bfloat16>(
 | 
			
		||||
           {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
 | 
			
		||||
           .get()});
 | 
			
		||||
           {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
 | 
			
		||||
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -60,10 +60,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
 | 
			
		||||
                                         float end, int seed) {
 | 
			
		||||
    *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
 | 
			
		||||
    r3_array->FillRandom(start, end, seed);
 | 
			
		||||
    auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
 | 
			
		||||
    auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
 | 
			
		||||
        LayoutUtil::MakeLayout(minor_to_major));
 | 
			
		||||
    std::unique_ptr<GlobalData> r3_global_data =
 | 
			
		||||
        client_->TransferToServer(*r3_data).ConsumeValueOrDie();
 | 
			
		||||
        client_->TransferToServer(r3_data).ConsumeValueOrDie();
 | 
			
		||||
    return r3_global_data;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -74,10 +74,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
 | 
			
		||||
                                         float end, int seed) {
 | 
			
		||||
    *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
 | 
			
		||||
    r2_array->FillRandom(start, end, seed);
 | 
			
		||||
    auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
 | 
			
		||||
    auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
 | 
			
		||||
        LayoutUtil::MakeLayout(minor_to_major));
 | 
			
		||||
    std::unique_ptr<GlobalData> r2_global_data =
 | 
			
		||||
        client_->TransferToServer(*r2_data).ConsumeValueOrDie();
 | 
			
		||||
        client_->TransferToServer(r2_data).ConsumeValueOrDie();
 | 
			
		||||
    return r2_global_data;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -293,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
 | 
			
		||||
  Add(ConstantR2<float>(&b, {{1.0, 5.0}}),
 | 
			
		||||
      ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
 | 
			
		||||
      ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
 | 
			
		||||
                              {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
 | 
			
		||||
      /*broadcast_dimensions=*/{1, 2});
 | 
			
		||||
 | 
			
		||||
@ -301,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
 | 
			
		||||
      LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
 | 
			
		||||
                                    {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct R3ImplicitBroadcastSpec {
 | 
			
		||||
@ -370,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
 | 
			
		||||
  }
 | 
			
		||||
  auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
 | 
			
		||||
  ComputeAndCompareLiteral(
 | 
			
		||||
      &builder, *expected,
 | 
			
		||||
      {r3_implicit_global_data.get(), r3_global_data.get()},
 | 
			
		||||
      &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()},
 | 
			
		||||
      ErrorSpec(1e-7, 1e-7));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -395,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
 | 
			
		||||
  auto expected =
 | 
			
		||||
      LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()},
 | 
			
		||||
                           ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
  auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}}));
 | 
			
		||||
  auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}}));
 | 
			
		||||
  auto r3 = ConstantLiteral(
 | 
			
		||||
      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
  Add(r3, r1);
 | 
			
		||||
 | 
			
		||||
  auto expected =
 | 
			
		||||
      LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
  auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}}));
 | 
			
		||||
  auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}, {2}}}));
 | 
			
		||||
  auto r3 = ConstantLiteral(
 | 
			
		||||
      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
  Add(r3, r1);
 | 
			
		||||
 | 
			
		||||
  auto expected =
 | 
			
		||||
      LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
  auto r1 =
 | 
			
		||||
      ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
 | 
			
		||||
      ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
 | 
			
		||||
  auto r3 = ConstantLiteral(
 | 
			
		||||
      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
  Add(r3, r1);
 | 
			
		||||
 | 
			
		||||
  auto expected =
 | 
			
		||||
      LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
  auto r1 =
 | 
			
		||||
      ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
 | 
			
		||||
      ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
 | 
			
		||||
  auto r3 = ConstantLiteral(
 | 
			
		||||
      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
  Add(r3, r1);
 | 
			
		||||
 | 
			
		||||
  auto expected =
 | 
			
		||||
      LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
  auto r1 = ConstantLiteral(
 | 
			
		||||
      &b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
 | 
			
		||||
      &b, LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
 | 
			
		||||
  auto r3 = ConstantLiteral(
 | 
			
		||||
      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
  Add(r3, r1);
 | 
			
		||||
 | 
			
		||||
  auto expected =
 | 
			
		||||
      LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
  auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}}));
 | 
			
		||||
  auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}}}));
 | 
			
		||||
  auto r3 = ConstantLiteral(
 | 
			
		||||
      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
  Add(r3, r1);
 | 
			
		||||
 | 
			
		||||
  auto expected =
 | 
			
		||||
      LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct R2ImplicitBroadcastSpec {
 | 
			
		||||
@ -618,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
 | 
			
		||||
  ComputeAndCompareLiteral(
 | 
			
		||||
      &builder, *expected,
 | 
			
		||||
      &builder, expected,
 | 
			
		||||
      {r2_implicit_global_data1.get(), r2_global_data.get(),
 | 
			
		||||
       r2_implicit_global_data2.get()},
 | 
			
		||||
      ErrorSpec(1e-6, 1e-6));
 | 
			
		||||
@ -630,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
  auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}}));
 | 
			
		||||
  auto r2 =
 | 
			
		||||
      ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
 | 
			
		||||
  auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}}));
 | 
			
		||||
  auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
 | 
			
		||||
  Add(r2, r1);
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
  auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}}));
 | 
			
		||||
  auto r2 =
 | 
			
		||||
      ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
 | 
			
		||||
  auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1}, {2}}));
 | 
			
		||||
  auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
 | 
			
		||||
  Add(r2, r1);
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
  auto r1 = ConstantR1<float>(&b, {10, 20});
 | 
			
		||||
  auto r3 = ConstantLiteral(
 | 
			
		||||
      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
  Add(r3, r1, {0});
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::CreateR3<float>(
 | 
			
		||||
      {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
  auto r1 = ConstantR1<float>(&b, {10, 20});
 | 
			
		||||
  auto r3 = ConstantLiteral(
 | 
			
		||||
      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
  Add(r1, r3, {1});
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::CreateR3<float>(
 | 
			
		||||
      {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
  auto r1 = ConstantR1<float>(&b, {10, 20});
 | 
			
		||||
  auto r3 = ConstantLiteral(
 | 
			
		||||
      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
  Add(r1, r3, {2});
 | 
			
		||||
 | 
			
		||||
  auto expected = LiteralUtil::CreateR3<float>(
 | 
			
		||||
      {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
 | 
			
		||||
@ -697,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
 | 
			
		||||
  auto r1_1 = ConstantR1<float>(&b, {100, 200});
 | 
			
		||||
  auto r1_2 = ConstantR1<float>(&b, {10, 20});
 | 
			
		||||
  auto r3 = ConstantLiteral(
 | 
			
		||||
      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
 | 
			
		||||
  for (int i = 0; i < 3; ++i) {
 | 
			
		||||
    r3 = Add(r1_0, r3, {0});
 | 
			
		||||
    r3 = Add(r3, r1_1, {1});
 | 
			
		||||
@ -709,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
 | 
			
		||||
      {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
 | 
			
		||||
       {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
 | 
			
		||||
@ -730,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
 | 
			
		||||
      {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
 | 
			
		||||
       {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
 | 
			
		||||
@ -739,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
 | 
			
		||||
  XlaBuilder b(TestName());
 | 
			
		||||
 | 
			
		||||
  Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
 | 
			
		||||
      ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
 | 
			
		||||
      ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
 | 
			
		||||
                              {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
 | 
			
		||||
      /*broadcast_dimensions=*/{1, 2});
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
 | 
			
		||||
  hlo_module->AddEntryComputation(builder.Build());
 | 
			
		||||
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0<float>(42.0),
 | 
			
		||||
                                    *result, error_spec_));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0<float>(42.0), result,
 | 
			
		||||
                                    error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
 | 
			
		||||
@ -63,7 +63,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
 | 
			
		||||
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
      *LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
 | 
			
		||||
      LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), result,
 | 
			
		||||
      error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -86,12 +86,12 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
 | 
			
		||||
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
      *LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
 | 
			
		||||
      LiteralSlice(*result, {0}), error_spec_));
 | 
			
		||||
      LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
 | 
			
		||||
      LiteralSlice(result, {0}), error_spec_));
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
      *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
 | 
			
		||||
      LiteralSlice(*result, {1}), error_spec_));
 | 
			
		||||
      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
 | 
			
		||||
      LiteralSlice(result, {1}), error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
 | 
			
		||||
@ -107,7 +107,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
 | 
			
		||||
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
      *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
 | 
			
		||||
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), result,
 | 
			
		||||
      error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -126,7 +126,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
 | 
			
		||||
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
      *LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
 | 
			
		||||
      LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), result,
 | 
			
		||||
      error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -143,9 +143,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
 | 
			
		||||
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
      *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
 | 
			
		||||
                                     {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
 | 
			
		||||
      *result, error_spec_));
 | 
			
		||||
      LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
 | 
			
		||||
                                    {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
 | 
			
		||||
      result, error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
 | 
			
		||||
@ -166,9 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
 | 
			
		||||
  Array2D<float> pz({{1, 2}, {1, 2}});
 | 
			
		||||
  expected.FillWithPZ(pz);
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(
 | 
			
		||||
      LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
 | 
			
		||||
                            *result, error_spec_));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
 | 
			
		||||
@ -197,9 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
 | 
			
		||||
  }
 | 
			
		||||
  expected.FillWithYX(yx);
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(
 | 
			
		||||
      LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
 | 
			
		||||
                            *result, error_spec_));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
 | 
			
		||||
@ -220,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
 | 
			
		||||
  hlo_module->AddEntryComputation(builder.Build());
 | 
			
		||||
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array),
 | 
			
		||||
                                    *result, error_spec_));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array),
 | 
			
		||||
                                    result, error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
 | 
			
		||||
@ -240,9 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
 | 
			
		||||
  Array4D<float> expected(64, 64, 3, 3);
 | 
			
		||||
  expected.Fill(1.0f);
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(
 | 
			
		||||
      LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
 | 
			
		||||
                            *result, error_spec_));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
 | 
			
		||||
@ -263,9 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
 | 
			
		||||
  Array4D<float> expected(3, 3, 2, 2);
 | 
			
		||||
  expected.FillWithYX(to_broadcast);
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(
 | 
			
		||||
      LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
 | 
			
		||||
                            *result, error_spec_));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
 | 
			
		||||
@ -295,9 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
 | 
			
		||||
  hlo_module->AddEntryComputation(builder.Build());
 | 
			
		||||
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(
 | 
			
		||||
      LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
 | 
			
		||||
                            *result, error_spec_));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -77,8 +77,7 @@ class CallOpTest : public ClientLibraryTestBase {
 | 
			
		||||
XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  XlaComputation callee = CreateR0F32IdentityComputation();
 | 
			
		||||
  auto constant =
 | 
			
		||||
      ConstantLiteral(&builder, *LiteralUtil::CreateR0<float>(42.0));
 | 
			
		||||
  auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0<float>(42.0));
 | 
			
		||||
  Call(&builder, callee, {constant});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
 | 
			
		||||
@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
 | 
			
		||||
XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  XlaComputation callee = CreateR1S0F32AdditionComputation();
 | 
			
		||||
  auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
 | 
			
		||||
  auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
 | 
			
		||||
  auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
 | 
			
		||||
  auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
 | 
			
		||||
  Call(&builder, callee, {x, y});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
 | 
			
		||||
@ -98,9 +97,9 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  XlaComputation callee = CreateR1S2F32AdditionComputation();
 | 
			
		||||
  auto x =
 | 
			
		||||
      ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
 | 
			
		||||
      ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
 | 
			
		||||
  auto y =
 | 
			
		||||
      ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
 | 
			
		||||
      ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
 | 
			
		||||
  Call(&builder, callee, {x, y});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
 | 
			
		||||
@ -133,7 +132,7 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) {
 | 
			
		||||
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      std::unique_ptr<GlobalData> start,
 | 
			
		||||
      client_->TransferToServer(*LiteralUtil::CreateR0<float>(1.0f)));
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR0<float>(1.0f)));
 | 
			
		||||
  ComputeAndCompareR0<float>(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -141,10 +140,10 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  XlaComputation callee = CreateR0F32TupleComputation();
 | 
			
		||||
  auto elem = LiteralUtil::CreateR0<float>(42.0);
 | 
			
		||||
  auto tuple = LiteralUtil::MakeTuple({elem.get()});
 | 
			
		||||
  Call(&builder, callee, {ConstantLiteral(&builder, *elem)});
 | 
			
		||||
  auto tuple = LiteralUtil::MakeTuple({&elem});
 | 
			
		||||
  Call(&builder, callee, {ConstantLiteral(&builder, elem)});
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f));
 | 
			
		||||
  ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -38,14 +38,14 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
 | 
			
		||||
  XlaBuilder builder("add_two_params");
 | 
			
		||||
  auto param_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f});
 | 
			
		||||
 | 
			
		||||
  auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0");
 | 
			
		||||
  auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1");
 | 
			
		||||
  auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0");
 | 
			
		||||
  auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1");
 | 
			
		||||
  Add(p0, p1);
 | 
			
		||||
 | 
			
		||||
  auto param0_data =
 | 
			
		||||
      client_->TransferToServer(*param_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto param1_data =
 | 
			
		||||
      client_->TransferToServer(*param_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto computation_status = builder.Build();
 | 
			
		||||
  ASSERT_IS_OK(computation_status.status());
 | 
			
		||||
@ -86,12 +86,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
 | 
			
		||||
  auto computation = computation_status.ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto f32_literal = LiteralUtil::CreateR0<float>(1.1f);
 | 
			
		||||
  auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto f32_4_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
 | 
			
		||||
  auto f32_4_data =
 | 
			
		||||
      client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(f32_4_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto u8_4_literal = LiteralUtil::CreateR1U8("hola");
 | 
			
		||||
  auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  // Match
 | 
			
		||||
  auto status = client_->Execute(
 | 
			
		||||
 | 
			
		||||
@ -101,7 +101,7 @@ StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
 | 
			
		||||
  return client_->Execute(computation, arguments, &execution_options_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
 | 
			
		||||
StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
 | 
			
		||||
    const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
 | 
			
		||||
    const Shape* shape_with_output_layout) {
 | 
			
		||||
  ExecutionOptions execution_options = execution_options_;
 | 
			
		||||
@ -113,7 +113,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
 | 
			
		||||
                                     &execution_options);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
 | 
			
		||||
StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
 | 
			
		||||
    XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
 | 
			
		||||
    const Shape* shape_with_output_layout) {
 | 
			
		||||
  // Build the computation, as a convenience.
 | 
			
		||||
@ -121,8 +121,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
 | 
			
		||||
  return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::unique_ptr<Literal>>
 | 
			
		||||
ClientLibraryTestBase::ExecuteAndTransferReference(
 | 
			
		||||
StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
 | 
			
		||||
    const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
 | 
			
		||||
    const Shape* shape_with_output_layout) {
 | 
			
		||||
  ExecutionOptions execution_options = execution_options_;
 | 
			
		||||
@ -148,15 +147,15 @@ string ClientLibraryTestBase::ExecuteToString(
 | 
			
		||||
  if (!result.ok()) {
 | 
			
		||||
    return result.status().ToString();
 | 
			
		||||
  } else {
 | 
			
		||||
    return result.ValueOrDie()->ToString();
 | 
			
		||||
    return result.ValueOrDie().ToString();
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ClientLibraryTestBase::ComputeAndCompareR1(
 | 
			
		||||
    XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
 | 
			
		||||
    absl::Span<GlobalData* const> arguments) {
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
 | 
			
		||||
  Literal expected_literal = LiteralUtil::CreateR1(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
 | 
			
		||||
                                                  arguments);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -182,7 +181,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
 | 
			
		||||
                             const string& error_message)>& verify_output) {
 | 
			
		||||
  // Try with no layout requirement.
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
 | 
			
		||||
  verify_output(*actual, "");
 | 
			
		||||
  verify_output(actual, "");
 | 
			
		||||
 | 
			
		||||
  // Try with all output layouts.
 | 
			
		||||
  std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape()));
 | 
			
		||||
@ -193,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
 | 
			
		||||
        AsInt64Slice(expected.shape().dimensions()), minor_to_major);
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(auto actual,
 | 
			
		||||
                        ExecuteAndTransfer(computation, arguments, &layout));
 | 
			
		||||
    verify_output(*actual,
 | 
			
		||||
    verify_output(actual,
 | 
			
		||||
                  absl::StrCat("Test with output layout: ",
 | 
			
		||||
                               ShapeUtil::HumanStringWithLayout(layout)));
 | 
			
		||||
  } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
 | 
			
		||||
@ -218,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
 | 
			
		||||
      TF_ASSIGN_OR_RETURN(auto literal,
 | 
			
		||||
                          client_->Transfer(*arguments[index], nullptr));
 | 
			
		||||
      // Skip tuples because they don't have a rank.
 | 
			
		||||
      if (ShapeUtil::IsTuple(literal->shape())) {
 | 
			
		||||
      if (ShapeUtil::IsTuple(literal.shape())) {
 | 
			
		||||
        layout_strings.push_back(
 | 
			
		||||
            ShapeUtil::HumanStringWithLayout(literal->shape()));
 | 
			
		||||
            ShapeUtil::HumanStringWithLayout(literal.shape()));
 | 
			
		||||
        arguments_with_layout.push_back(arguments[index]);
 | 
			
		||||
        TF_RETURN_IF_ERROR(choose(index + 1));
 | 
			
		||||
        arguments_with_layout.pop_back();
 | 
			
		||||
@ -228,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
 | 
			
		||||
        return Status::OK();
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape()));
 | 
			
		||||
      std::vector<int64> minor_to_major(ShapeUtil::Rank(literal.shape()));
 | 
			
		||||
      std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
 | 
			
		||||
      do {
 | 
			
		||||
        auto literal_relayout =
 | 
			
		||||
            literal->Relayout(LayoutUtil::MakeLayout(minor_to_major));
 | 
			
		||||
            literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
 | 
			
		||||
        layout_strings.push_back(
 | 
			
		||||
            ShapeUtil::HumanStringWithLayout(literal_relayout->shape()));
 | 
			
		||||
            ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
 | 
			
		||||
        TF_ASSIGN_OR_RETURN(auto data,
 | 
			
		||||
                            client_->TransferToServer(*literal_relayout));
 | 
			
		||||
                            client_->TransferToServer(literal_relayout));
 | 
			
		||||
        arguments_with_layout.push_back(data.get());
 | 
			
		||||
        TF_RETURN_IF_ERROR(choose(index + 1));
 | 
			
		||||
        arguments_with_layout.pop_back();
 | 
			
		||||
@ -256,7 +255,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
 | 
			
		||||
    for (const auto& str : layout_strings) {
 | 
			
		||||
      absl::StrAppend(&error_message, str, " ");
 | 
			
		||||
    }
 | 
			
		||||
    verify_output(*actual, error_message);
 | 
			
		||||
    verify_output(actual, error_message);
 | 
			
		||||
    return Status::OK();
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
@ -290,11 +289,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
 | 
			
		||||
  // We allow using a float expected literal for a bfloat16 output. In this
 | 
			
		||||
  // case, we need to convert the expected literal to bfloat16.
 | 
			
		||||
  const Literal* expected_ptr = &expected;
 | 
			
		||||
  std::unique_ptr<Literal> converted_expected;
 | 
			
		||||
  Literal converted_expected;
 | 
			
		||||
  Shape layout_shape;
 | 
			
		||||
  if (use_bfloat16_) {
 | 
			
		||||
    converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
 | 
			
		||||
    expected_ptr = converted_expected.get();
 | 
			
		||||
    expected_ptr = &converted_expected;
 | 
			
		||||
    if (shape_with_layout != nullptr) {
 | 
			
		||||
      layout_shape = *shape_with_layout;
 | 
			
		||||
      ShapeUtil::ForEachMutableSubshape(
 | 
			
		||||
@ -319,7 +318,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
 | 
			
		||||
  }
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
 | 
			
		||||
                                                      shape_with_layout));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -346,11 +345,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
 | 
			
		||||
  // We allow using a float expected literal for a bfloat16 output. In this
 | 
			
		||||
  // case, we need to convert the expected literal to bfloat16.
 | 
			
		||||
  const Literal* expected_ptr = &expected;
 | 
			
		||||
  std::unique_ptr<Literal> converted_expected;
 | 
			
		||||
  Literal converted_expected;
 | 
			
		||||
  Shape layout_shape;
 | 
			
		||||
  if (use_bfloat16_) {
 | 
			
		||||
    converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
 | 
			
		||||
    expected_ptr = converted_expected.get();
 | 
			
		||||
    expected_ptr = &converted_expected;
 | 
			
		||||
    if (shape_with_layout != nullptr) {
 | 
			
		||||
      layout_shape = *shape_with_layout;
 | 
			
		||||
      ShapeUtil::ForEachMutableSubshape(
 | 
			
		||||
@ -376,7 +375,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
 | 
			
		||||
  }
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
 | 
			
		||||
                                                      shape_with_layout));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -391,12 +390,12 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
 | 
			
		||||
  auto actual = actual_status.ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  // Turn the expected value into a literal.
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
 | 
			
		||||
  Literal expected_literal = LiteralUtil::CreateR1U8(expected);
 | 
			
		||||
 | 
			
		||||
  VLOG(1) << "expected: " << expected_literal->ToString();
 | 
			
		||||
  VLOG(1) << "actual:   " << actual->ToString();
 | 
			
		||||
  VLOG(1) << "expected: " << expected_literal.ToString();
 | 
			
		||||
  VLOG(1) << "actual:   " << actual.ToString();
 | 
			
		||||
 | 
			
		||||
  EXPECT_EQ(expected, actual->GetR1U8AsString());
 | 
			
		||||
  EXPECT_EQ(expected, actual.GetR1U8AsString());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ClientLibraryTestBase::ComputeAndCompareTuple(
 | 
			
		||||
@ -408,7 +407,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
  auto actual = actual_status.ConsumeValueOrDie();
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ClientLibraryTestBase::ComputeAndCompareTuple(
 | 
			
		||||
@ -420,7 +419,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
  auto actual = actual_status.ConsumeValueOrDie();
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ClientLibraryTestBase::ComputeAndCompare(
 | 
			
		||||
@ -430,9 +429,9 @@ void ClientLibraryTestBase::ComputeAndCompare(
 | 
			
		||||
  if (!status_or_data.ok()) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
  std::unique_ptr<Literal> reference, result;
 | 
			
		||||
  Literal reference, result;
 | 
			
		||||
  std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ClientLibraryTestBase::ComputeAndCompare(
 | 
			
		||||
@ -442,12 +441,12 @@ void ClientLibraryTestBase::ComputeAndCompare(
 | 
			
		||||
  if (!status_or_data.ok()) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
  std::unique_ptr<Literal> reference, result;
 | 
			
		||||
  Literal reference, result;
 | 
			
		||||
  std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
 | 
			
		||||
StatusOr<std::pair<Literal, Literal>>
 | 
			
		||||
ClientLibraryTestBase::ComputeValueAndReference(
 | 
			
		||||
    XlaBuilder* builder, absl::Span<const Literal> arguments) {
 | 
			
		||||
  // Transfer the arguments to the executor service. We put the unique_ptr's
 | 
			
		||||
@ -569,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
 | 
			
		||||
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
 | 
			
		||||
                                                       XlaBuilder* builder) {
 | 
			
		||||
  return ConstantLiteral(builder, use_bfloat16_
 | 
			
		||||
                                      ? *LiteralUtil::ConvertF32ToBF16(literal)
 | 
			
		||||
                                      : literal);
 | 
			
		||||
                                      ? LiteralUtil::ConvertF32ToBF16(literal)
 | 
			
		||||
                                      : LiteralSlice(literal));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::unique_ptr<GlobalData>
 | 
			
		||||
@ -600,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
 | 
			
		||||
Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
 | 
			
		||||
    const Literal& literal) {
 | 
			
		||||
  if (use_bfloat16_) {
 | 
			
		||||
    return std::move(*LiteralUtil::ConvertF32ToBF16(literal));
 | 
			
		||||
    return LiteralUtil::ConvertF32ToBF16(literal);
 | 
			
		||||
  }
 | 
			
		||||
  return literal.Clone();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -95,11 +95,11 @@ class ClientLibraryTestBase : public ::testing::Test {
 | 
			
		||||
  StatusOr<std::unique_ptr<GlobalData>> Execute(
 | 
			
		||||
      XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
 | 
			
		||||
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
 | 
			
		||||
  StatusOr<Literal> ExecuteAndTransfer(
 | 
			
		||||
      XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
 | 
			
		||||
      const Shape* shape_with_output_layout = nullptr);
 | 
			
		||||
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
 | 
			
		||||
  StatusOr<Literal> ExecuteAndTransfer(
 | 
			
		||||
      const XlaComputation& computation,
 | 
			
		||||
      absl::Span<GlobalData* const> arguments,
 | 
			
		||||
      const Shape* shape_with_output_layout = nullptr);
 | 
			
		||||
@ -107,7 +107,7 @@ class ClientLibraryTestBase : public ::testing::Test {
 | 
			
		||||
  // This executes the computation via the reference client (which connects a
 | 
			
		||||
  // interpreter backend). The result is used as the expected values of the
 | 
			
		||||
  // computation.
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ExecuteAndTransferReference(
 | 
			
		||||
  StatusOr<Literal> ExecuteAndTransferReference(
 | 
			
		||||
      const XlaComputation& computation,
 | 
			
		||||
      absl::Span<GlobalData* const> arguments,
 | 
			
		||||
      const Shape* shape_with_output_layout = nullptr);
 | 
			
		||||
@ -282,7 +282,7 @@ class ClientLibraryTestBase : public ::testing::Test {
 | 
			
		||||
 | 
			
		||||
  template <class T>
 | 
			
		||||
  XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
 | 
			
		||||
    return AddParam(*LiteralUtil::CreateFromArray(argument), builder);
 | 
			
		||||
    return AddParam(LiteralUtil::CreateFromArray(argument), builder);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Creates a constant instruction with the given literal. When the
 | 
			
		||||
@ -297,14 +297,14 @@ class ClientLibraryTestBase : public ::testing::Test {
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  XlaOp CreateConstantFromArray(const Array<NativeT>& array,
 | 
			
		||||
                                XlaBuilder* builder) {
 | 
			
		||||
    return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array),
 | 
			
		||||
    return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
 | 
			
		||||
                                     builder);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Same as CreateConstantFromArray, but for scalars.
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
 | 
			
		||||
    return CreateConstantFromLiteral(*LiteralUtil::CreateR0<NativeT>(value),
 | 
			
		||||
    return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
 | 
			
		||||
                                     builder);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -375,9 +375,8 @@ class ClientLibraryTestBase : public ::testing::Test {
 | 
			
		||||
  // Executes the computation and calculates the expected reference value using
 | 
			
		||||
  // the reference client. Returns two literals in the order of (expected,
 | 
			
		||||
  // actual).
 | 
			
		||||
  StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
 | 
			
		||||
  ComputeValueAndReference(XlaBuilder* builder,
 | 
			
		||||
                           absl::Span<const Literal> arguments);
 | 
			
		||||
  StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
 | 
			
		||||
      XlaBuilder* builder, absl::Span<const Literal> arguments);
 | 
			
		||||
 | 
			
		||||
  Client* client_;
 | 
			
		||||
  Client* ref_client_;  // To compute reference result.
 | 
			
		||||
@ -412,9 +411,8 @@ template <typename NativeT>
 | 
			
		||||
void ClientLibraryTestBase::ComputeAndCompareR0(
 | 
			
		||||
    XlaBuilder* builder, NativeT expected,
 | 
			
		||||
    absl::Span<GlobalData* const> arguments) {
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
      LiteralUtil::CreateR0<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
 | 
			
		||||
  Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
 | 
			
		||||
                                                  arguments);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -428,9 +426,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
 | 
			
		||||
                    std::is_same<NativeT, half>::value ||
 | 
			
		||||
                    std::is_same<NativeT, complex64>::value,
 | 
			
		||||
                "Float or complex type required when specifying an ErrorSpec");
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
      LiteralUtil::CreateR0<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
 | 
			
		||||
  Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
 | 
			
		||||
                                                  arguments, error);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -438,9 +435,8 @@ template <typename NativeT>
 | 
			
		||||
void ClientLibraryTestBase::ComputeAndCompareR1(
 | 
			
		||||
    XlaBuilder* builder, absl::Span<const NativeT> expected,
 | 
			
		||||
    absl::Span<GlobalData* const> arguments) {
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
      LiteralUtil::CreateR1<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
 | 
			
		||||
  Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
 | 
			
		||||
                                                  arguments);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -454,9 +450,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
 | 
			
		||||
                    std::is_same<NativeT, half>::value ||
 | 
			
		||||
                    std::is_same<NativeT, complex64>::value,
 | 
			
		||||
                "Float or complex type required when specifying an ErrorSpec");
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
      LiteralUtil::CreateR1<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
 | 
			
		||||
  Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
 | 
			
		||||
                                                  arguments, error);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -464,9 +459,9 @@ template <typename NativeT>
 | 
			
		||||
void ClientLibraryTestBase::ComputeAndCompareR2(
 | 
			
		||||
    XlaBuilder* builder, const Array2D<NativeT>& expected,
 | 
			
		||||
    absl::Span<GlobalData* const> arguments) {
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
  Literal expected_literal =
 | 
			
		||||
      LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
 | 
			
		||||
                                                  arguments);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -480,9 +475,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
 | 
			
		||||
                    std::is_same<NativeT, half>::value ||
 | 
			
		||||
                    std::is_same<NativeT, complex64>::value,
 | 
			
		||||
                "Float or complex type required when specifying an ErrorSpec");
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
  Literal expected_literal =
 | 
			
		||||
      LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
 | 
			
		||||
                                                  arguments, error);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -490,9 +485,9 @@ template <typename NativeT>
 | 
			
		||||
void ClientLibraryTestBase::ComputeAndCompareR3(
 | 
			
		||||
    XlaBuilder* builder, const Array3D<NativeT>& expected,
 | 
			
		||||
    absl::Span<GlobalData* const> arguments) {
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
  Literal expected_literal =
 | 
			
		||||
      LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
 | 
			
		||||
                                                  arguments);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -506,9 +501,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
 | 
			
		||||
                    std::is_same<NativeT, half>::value ||
 | 
			
		||||
                    std::is_same<NativeT, complex64>::value,
 | 
			
		||||
                "Float or complex type required when specifying an ErrorSpec");
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
  Literal expected_literal =
 | 
			
		||||
      LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
 | 
			
		||||
                                                  arguments, error);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -516,9 +511,9 @@ template <typename NativeT>
 | 
			
		||||
void ClientLibraryTestBase::ComputeAndCompareR4(
 | 
			
		||||
    XlaBuilder* builder, const Array4D<NativeT>& expected,
 | 
			
		||||
    absl::Span<GlobalData* const> arguments) {
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
  Literal expected_literal =
 | 
			
		||||
      LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
 | 
			
		||||
                                                  arguments);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -532,9 +527,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
 | 
			
		||||
                    std::is_same<NativeT, half>::value ||
 | 
			
		||||
                    std::is_same<NativeT, complex64>::value,
 | 
			
		||||
                "Float or complex type required when specifying an ErrorSpec");
 | 
			
		||||
  std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
  Literal expected_literal =
 | 
			
		||||
      LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
 | 
			
		||||
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
 | 
			
		||||
                                                  arguments, error);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -542,13 +537,13 @@ template <typename NativeT>
 | 
			
		||||
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
 | 
			
		||||
    NativeT value, int64 parameter_number, const string& name,
 | 
			
		||||
    XlaBuilder* builder, XlaOp* data_handle) {
 | 
			
		||||
  std::unique_ptr<Literal> literal = LiteralUtil::CreateR0(value);
 | 
			
		||||
  if (use_bfloat16_ && literal->shape().element_type() == F32) {
 | 
			
		||||
    literal = LiteralUtil::ConvertF32ToBF16(*literal);
 | 
			
		||||
  Literal literal = LiteralUtil::CreateR0(value);
 | 
			
		||||
  if (use_bfloat16_ && literal.shape().element_type() == F32) {
 | 
			
		||||
    literal = LiteralUtil::ConvertF32ToBF16(literal);
 | 
			
		||||
  }
 | 
			
		||||
  std::unique_ptr<GlobalData> data =
 | 
			
		||||
      client_->TransferToServer(*literal).ConsumeValueOrDie();
 | 
			
		||||
  *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
 | 
			
		||||
      client_->TransferToServer(literal).ConsumeValueOrDie();
 | 
			
		||||
  *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
 | 
			
		||||
  return data;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -556,13 +551,13 @@ template <typename NativeT>
 | 
			
		||||
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
 | 
			
		||||
    absl::Span<const NativeT> values, int64 parameter_number,
 | 
			
		||||
    const string& name, XlaBuilder* builder, XlaOp* data_handle) {
 | 
			
		||||
  std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
 | 
			
		||||
  if (use_bfloat16_ && literal->shape().element_type() == F32) {
 | 
			
		||||
    literal = LiteralUtil::ConvertF32ToBF16(*literal);
 | 
			
		||||
  Literal literal = LiteralUtil::CreateR1(values);
 | 
			
		||||
  if (use_bfloat16_ && literal.shape().element_type() == F32) {
 | 
			
		||||
    literal = LiteralUtil::ConvertF32ToBF16(literal);
 | 
			
		||||
  }
 | 
			
		||||
  std::unique_ptr<GlobalData> data =
 | 
			
		||||
      client_->TransferToServer(*literal).ConsumeValueOrDie();
 | 
			
		||||
  *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
 | 
			
		||||
      client_->TransferToServer(literal).ConsumeValueOrDie();
 | 
			
		||||
  *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
 | 
			
		||||
  return data;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -570,13 +565,13 @@ template <typename NativeT>
 | 
			
		||||
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
 | 
			
		||||
    const Array2D<NativeT>& array_2d, int64 parameter_number,
 | 
			
		||||
    const string& name, XlaBuilder* builder, XlaOp* data_handle) {
 | 
			
		||||
  std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d);
 | 
			
		||||
  if (use_bfloat16_ && literal->shape().element_type() == F32) {
 | 
			
		||||
    literal = LiteralUtil::ConvertF32ToBF16(*literal);
 | 
			
		||||
  Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
 | 
			
		||||
  if (use_bfloat16_ && literal.shape().element_type() == F32) {
 | 
			
		||||
    literal = LiteralUtil::ConvertF32ToBF16(literal);
 | 
			
		||||
  }
 | 
			
		||||
  std::unique_ptr<GlobalData> data =
 | 
			
		||||
      client_->TransferToServer(*literal).ConsumeValueOrDie();
 | 
			
		||||
  *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
 | 
			
		||||
      client_->TransferToServer(literal).ConsumeValueOrDie();
 | 
			
		||||
  *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
 | 
			
		||||
  return data;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -584,13 +579,13 @@ template <typename NativeT>
 | 
			
		||||
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
 | 
			
		||||
    const Array3D<NativeT>& array_3d, int64 parameter_number,
 | 
			
		||||
    const string& name, XlaBuilder* builder, XlaOp* data_handle) {
 | 
			
		||||
  std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(array_3d);
 | 
			
		||||
  if (use_bfloat16_ && literal->shape().element_type() == F32) {
 | 
			
		||||
    literal = LiteralUtil::ConvertF32ToBF16(*literal);
 | 
			
		||||
  Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
 | 
			
		||||
  if (use_bfloat16_ && literal.shape().element_type() == F32) {
 | 
			
		||||
    literal = LiteralUtil::ConvertF32ToBF16(literal);
 | 
			
		||||
  }
 | 
			
		||||
  std::unique_ptr<GlobalData> data =
 | 
			
		||||
      client_->TransferToServer(*literal).ConsumeValueOrDie();
 | 
			
		||||
  *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
 | 
			
		||||
      client_->TransferToServer(literal).ConsumeValueOrDie();
 | 
			
		||||
  *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
 | 
			
		||||
  return data;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
 | 
			
		||||
          std::unique_ptr<GlobalData> data,
 | 
			
		||||
          client_->Execute(computation, {}, &execution_options));
 | 
			
		||||
 | 
			
		||||
      std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
          LiteralUtil::CreateR2WithLayout<int32>(
 | 
			
		||||
              {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
 | 
			
		||||
      Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
 | 
			
		||||
          {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
 | 
			
		||||
 | 
			
		||||
      TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
          auto computed, client_->Transfer(*data, &expected_literal->shape()));
 | 
			
		||||
          auto computed, client_->Transfer(*data, &expected_literal.shape()));
 | 
			
		||||
 | 
			
		||||
      ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
 | 
			
		||||
          expected_literal->shape(), computed->shape()));
 | 
			
		||||
      EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
 | 
			
		||||
          expected_literal.shape(), computed.shape()));
 | 
			
		||||
      EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
 | 
			
		||||
      auto result,
 | 
			
		||||
      client_->ExecuteAndTransfer(computation, {}, &execution_options));
 | 
			
		||||
  LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
 | 
			
		||||
                                        LiteralSlice(*result, {0}));
 | 
			
		||||
                                        LiteralSlice(result, {0}));
 | 
			
		||||
  LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
 | 
			
		||||
                                        LiteralSlice(*result, {1}));
 | 
			
		||||
                                        LiteralSlice(result, {1}));
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
 | 
			
		||||
  EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
 | 
			
		||||
  EXPECT_TRUE(ShapeUtil::IsTuple(result.shape()));
 | 
			
		||||
  EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape()));
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(ShapeUtil::Equal(
 | 
			
		||||
      ShapeUtil::GetTupleElementShape(result->shape(), 0),
 | 
			
		||||
      ShapeUtil::GetTupleElementShape(result.shape(), 0),
 | 
			
		||||
      ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
 | 
			
		||||
                                     /*minor_to_major=*/{0, 1})));
 | 
			
		||||
  EXPECT_TRUE(ShapeUtil::Equal(
 | 
			
		||||
      ShapeUtil::GetTupleElementShape(result->shape(), 1),
 | 
			
		||||
      ShapeUtil::GetTupleElementShape(result.shape(), 1),
 | 
			
		||||
      ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
 | 
			
		||||
                                     /*minor_to_major=*/{1, 0})));
 | 
			
		||||
}
 | 
			
		||||
@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
 | 
			
		||||
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> const_arg,
 | 
			
		||||
                          client_->TransferToServer(
 | 
			
		||||
                              *LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
 | 
			
		||||
                              LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
 | 
			
		||||
 | 
			
		||||
  XlaBuilder b(TestName() + ".add");
 | 
			
		||||
  Add(Parameter(&b, 0, shape, "param_0"),
 | 
			
		||||
@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
 | 
			
		||||
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      auto result_literal,
 | 
			
		||||
      client_->Transfer(*results[0], &expected_result->shape()));
 | 
			
		||||
      client_->Transfer(*results[0], &expected_result.shape()));
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -42,14 +42,14 @@ class CompilationCacheTest : public ClientLibraryTestBase {
 | 
			
		||||
                               absl::Span<GlobalData* const> arguments,
 | 
			
		||||
                               float expected_result, bool expect_cache_hit) {
 | 
			
		||||
    ExecutionProfile execution_profile;
 | 
			
		||||
    std::unique_ptr<Literal> result =
 | 
			
		||||
    Literal result =
 | 
			
		||||
        client_
 | 
			
		||||
            ->ExecuteAndTransfer(computation, arguments,
 | 
			
		||||
                                 /*execution_options=*/&execution_options_,
 | 
			
		||||
                                 &execution_profile)
 | 
			
		||||
            .ConsumeValueOrDie();
 | 
			
		||||
    EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
        *LiteralUtil::CreateR0<float>(expected_result), *result, error_spec_));
 | 
			
		||||
        LiteralUtil::CreateR0<float>(expected_result), result, error_spec_));
 | 
			
		||||
    EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -63,10 +63,9 @@ class CompilationCacheTest : public ClientLibraryTestBase {
 | 
			
		||||
                           ->Execute(computation, arguments,
 | 
			
		||||
                                     &execution_options_, &execution_profile)
 | 
			
		||||
                           .ConsumeValueOrDie();
 | 
			
		||||
    std::unique_ptr<Literal> result =
 | 
			
		||||
        client_->Transfer(*data_handle).ConsumeValueOrDie();
 | 
			
		||||
    Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie();
 | 
			
		||||
    EXPECT_TRUE(LiteralTestUtil::Near(
 | 
			
		||||
        *LiteralUtil::CreateR2<float>(expected_result), *result, error_spec_));
 | 
			
		||||
        LiteralUtil::CreateR2<float>(expected_result), result, error_spec_));
 | 
			
		||||
    EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -88,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) {
 | 
			
		||||
XLA_TEST_F(CompilationCacheTest,
 | 
			
		||||
           DISABLED_ComputationCalledWithDifferentParameters) {
 | 
			
		||||
  std::unique_ptr<GlobalData> data_42 =
 | 
			
		||||
      client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR0<float>(42.0f))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<GlobalData> data_123 =
 | 
			
		||||
      client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR0<float>(123.0f))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<GlobalData> data_456 =
 | 
			
		||||
      client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR0<float>(456.0f))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
@ -145,12 +144,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) {
 | 
			
		||||
  auto rowmaj_array = LiteralUtil::CreateR2WithLayout(
 | 
			
		||||
      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
 | 
			
		||||
  auto rowmaj_handle =
 | 
			
		||||
      client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(rowmaj_array).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto colmaj_array = LiteralUtil::CreateR2WithLayout(
 | 
			
		||||
      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
 | 
			
		||||
  auto colmaj_handle =
 | 
			
		||||
      client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(colmaj_array).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
 | 
			
		||||
 | 
			
		||||
@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test {
 | 
			
		||||
    LOG(FATAL) << "invalid client_type value";
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
 | 
			
		||||
      Client* client, const XlaOp& operand, XlaBuilder* builder,
 | 
			
		||||
      Layout* output_layout = nullptr) {
 | 
			
		||||
  StatusOr<Literal> ComputeConstantLiteral(Client* client, const XlaOp& operand,
 | 
			
		||||
                                           XlaBuilder* builder,
 | 
			
		||||
                                           Layout* output_layout = nullptr) {
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(auto computed,
 | 
			
		||||
                        client->ComputeConstant(subgraph, output_layout));
 | 
			
		||||
@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test {
 | 
			
		||||
                                         XlaBuilder* builder) {
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
 | 
			
		||||
                                                             builder, nullptr));
 | 
			
		||||
    return literal->Get<Scalar>({});
 | 
			
		||||
    return literal.Get<Scalar>({});
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool IsConstant(const XlaOp& operand, XlaBuilder* builder) {
 | 
			
		||||
@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
 | 
			
		||||
 | 
			
		||||
    TF_ASSERT_OK_AND_ASSIGN(auto computed,
 | 
			
		||||
                            ComputeConstantLiteral(client, computation, &b));
 | 
			
		||||
    std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
        LiteralUtil::CreateR1<int32>({4, 6});
 | 
			
		||||
    EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
 | 
			
		||||
    Literal expected_literal = LiteralUtil::CreateR1<int32>({4, 6});
 | 
			
		||||
    EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
 | 
			
		||||
 | 
			
		||||
    TF_ASSERT_OK_AND_ASSIGN(auto computed,
 | 
			
		||||
                            ComputeConstantLiteral(client, computation, &b));
 | 
			
		||||
    std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
 | 
			
		||||
    EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
 | 
			
		||||
    Literal expected_literal = LiteralUtil::CreateR0<int32>(5);
 | 
			
		||||
    EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
 | 
			
		||||
                                 ConstantR2<int32>(&b, {{10, 20}, {30, 40}})),
 | 
			
		||||
                             &b, &layout_proto));
 | 
			
		||||
 | 
			
		||||
      std::unique_ptr<Literal> expected_literal =
 | 
			
		||||
          LiteralUtil::CreateR2WithLayout<int32>(
 | 
			
		||||
              {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
 | 
			
		||||
      Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
 | 
			
		||||
          {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
 | 
			
		||||
      ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
 | 
			
		||||
          expected_literal->shape(), computed->shape()));
 | 
			
		||||
      EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
 | 
			
		||||
          expected_literal.shape(), computed.shape()));
 | 
			
		||||
      EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -536,8 +536,8 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
 | 
			
		||||
  auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
 | 
			
		||||
  auto x_literal = LiteralUtil::CreateR0<float>(2.f);
 | 
			
		||||
  auto y_literal = LiteralUtil::CreateR0<float>(3.f);
 | 
			
		||||
  auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  auto x = Parameter(&builder, 0, f32_scalar, "x");
 | 
			
		||||
@ -559,12 +559,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
 | 
			
		||||
  auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
 | 
			
		||||
  auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
 | 
			
		||||
  auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
 | 
			
		||||
  auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  auto x = Parameter(&builder, 0, x_literal->shape(), "x");
 | 
			
		||||
  auto x = Parameter(&builder, 0, x_literal.shape(), "x");
 | 
			
		||||
  auto y = Parameter(&builder, 1, f32_scalar, "y");
 | 
			
		||||
  auto z = Parameter(&builder, 2, f32_scalar, "z");
 | 
			
		||||
  auto bcast = Broadcast(y, {5});
 | 
			
		||||
@ -587,12 +587,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
 | 
			
		||||
  auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
 | 
			
		||||
  auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
 | 
			
		||||
  auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
 | 
			
		||||
  auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  auto x = Parameter(&builder, 0, x_literal->shape(), "x");
 | 
			
		||||
  auto x = Parameter(&builder, 0, x_literal.shape(), "x");
 | 
			
		||||
  auto y = Parameter(&builder, 1, f32_scalar, "y");
 | 
			
		||||
  auto z = Parameter(&builder, 2, f32_scalar, "y");
 | 
			
		||||
  auto y_bcast = Broadcast(y, {1, 5, 7});
 | 
			
		||||
 | 
			
		||||
@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(
 | 
			
		||||
      &builder,
 | 
			
		||||
      *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
 | 
			
		||||
                               LiteralUtil::CreateR0<float>(25.0f).get()}),
 | 
			
		||||
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
 | 
			
		||||
                                        LiteralUtil::CreateR0<float>(25.0f)}),
 | 
			
		||||
      {pred_arg.get()}, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
 | 
			
		||||
  Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
 | 
			
		||||
              CreateR1TupleFloorComputation());
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(
 | 
			
		||||
      &builder,
 | 
			
		||||
      *LiteralUtil::MakeTuple(
 | 
			
		||||
          {LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
 | 
			
		||||
           LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
 | 
			
		||||
      {pred_arg.get()}, error_spec_);
 | 
			
		||||
  ComputeAndCompareTuple(&builder,
 | 
			
		||||
                         LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
                             {LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
 | 
			
		||||
                              LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
 | 
			
		||||
                         {pred_arg.get()}, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Test true and false computations that return a tuple of a predicate, a
 | 
			
		||||
@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
 | 
			
		||||
  Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
 | 
			
		||||
              false_builder_result.ConsumeValueOrDie());
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(
 | 
			
		||||
      &builder,
 | 
			
		||||
      *LiteralUtil::MakeTuple(
 | 
			
		||||
          {LiteralUtil::CreateR0<bool>(true).get(),
 | 
			
		||||
           LiteralUtil::CreateR0<float>(12.2f).get(),
 | 
			
		||||
           LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
 | 
			
		||||
      {pred_arg.get()}, error_spec_);
 | 
			
		||||
  ComputeAndCompareTuple(&builder,
 | 
			
		||||
                         LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
                             {LiteralUtil::CreateR0<bool>(true),
 | 
			
		||||
                              LiteralUtil::CreateR0<float>(12.2f),
 | 
			
		||||
                              LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
 | 
			
		||||
                         {pred_arg.get()}, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Test true and false computations that return a nested tuple.
 | 
			
		||||
@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareTuple(
 | 
			
		||||
      &builder,
 | 
			
		||||
      *LiteralUtil::MakeTuple(
 | 
			
		||||
          {LiteralUtil::MakeTuple(
 | 
			
		||||
               {LiteralUtil::CreateR0<float>(46.6f).get(),
 | 
			
		||||
                LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()})
 | 
			
		||||
               .get(),
 | 
			
		||||
           LiteralUtil::MakeTuple(
 | 
			
		||||
               {LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
 | 
			
		||||
                LiteralUtil::CreateR0<float>(9.3f).get()})
 | 
			
		||||
               .get()}),
 | 
			
		||||
      LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
          {LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
               {LiteralUtil::CreateR0<float>(46.6f),
 | 
			
		||||
                LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
 | 
			
		||||
           LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
               {LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
 | 
			
		||||
                LiteralUtil::CreateR0<float>(9.3f)})}),
 | 
			
		||||
      {pred_arg.get()}, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
 | 
			
		||||
 | 
			
		||||
    ComputeAndCompareTuple(
 | 
			
		||||
        &builder,
 | 
			
		||||
        *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
 | 
			
		||||
                                 LiteralUtil::CreateR0<float>(b).get()}),
 | 
			
		||||
        LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
            {LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
 | 
			
		||||
        {x_arg.get(), y_arg.get()}, error_spec_);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
@ -669,10 +665,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
 | 
			
		||||
  {
 | 
			
		||||
    // Pred is true case.
 | 
			
		||||
    std::vector<Literal> args;
 | 
			
		||||
    args.push_back(std::move(
 | 
			
		||||
        *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
 | 
			
		||||
                                 LiteralUtil::CreateR0<int32>(-42).get()})));
 | 
			
		||||
    args.push_back(std::move(*LiteralUtil::CreateR0<bool>(true)));
 | 
			
		||||
    args.push_back(
 | 
			
		||||
        LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
 | 
			
		||||
                                          LiteralUtil::CreateR0<int32>(-42)}));
 | 
			
		||||
    args.push_back(LiteralUtil::CreateR0<bool>(true));
 | 
			
		||||
    XlaBuilder builder(TestName() + ".main");
 | 
			
		||||
    auto p = Parameter(&builder, 0, tuple2, "p0");
 | 
			
		||||
    auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
 | 
			
		||||
@ -682,10 +678,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
 | 
			
		||||
  {
 | 
			
		||||
    // Pred is false case.
 | 
			
		||||
    std::vector<Literal> args;
 | 
			
		||||
    args.push_back(std::move(
 | 
			
		||||
        *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
 | 
			
		||||
                                 LiteralUtil::CreateR0<int32>(-42).get()})));
 | 
			
		||||
    args.push_back(std::move(*LiteralUtil::CreateR0<bool>(false)));
 | 
			
		||||
    args.push_back(
 | 
			
		||||
        LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
 | 
			
		||||
                                          LiteralUtil::CreateR0<int32>(-42)}));
 | 
			
		||||
    args.push_back(LiteralUtil::CreateR0<bool>(false));
 | 
			
		||||
    XlaBuilder builder(TestName() + ".main");
 | 
			
		||||
    auto p = Parameter(&builder, 0, tuple2, "p0");
 | 
			
		||||
    auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
 | 
			
		||||
 | 
			
		||||
@ -110,7 +110,7 @@ TEST_F(ConstantsTest, Small_2x2) {
 | 
			
		||||
 | 
			
		||||
TEST_F(ConstantsTest, Empty_3x0x2) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(
 | 
			
		||||
  ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(
 | 
			
		||||
                                Array3D<float>(3, 0, 2)));
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
 | 
			
		||||
@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) {
 | 
			
		||||
      {{5.f, 6.f},   // y0
 | 
			
		||||
       {7.f, 8.f}},  // y1
 | 
			
		||||
  });
 | 
			
		||||
  ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(array3d));
 | 
			
		||||
  ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(array3d));
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR3<float>(&builder, array3d, {});
 | 
			
		||||
}
 | 
			
		||||
@ -140,12 +140,11 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
 | 
			
		||||
      {5.0f, 4.4f},   // p2
 | 
			
		||||
  });
 | 
			
		||||
  input_array.FillWithPZ(pz);
 | 
			
		||||
  std::unique_ptr<Literal> input_literal =
 | 
			
		||||
      LiteralUtil::CreateR4FromArray4D(input_array);
 | 
			
		||||
  Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array);
 | 
			
		||||
 | 
			
		||||
  {
 | 
			
		||||
    XlaBuilder builder(TestName());
 | 
			
		||||
    ConstantLiteral(&builder, *input_literal);
 | 
			
		||||
    ConstantLiteral(&builder, input_literal);
 | 
			
		||||
    ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -159,23 +158,21 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
 | 
			
		||||
// TODO(b/29263943): Support tuple constants.
 | 
			
		||||
TEST_F(ConstantsTest, DISABLED_TupleConstant) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  ConstantLiteral(&builder,
 | 
			
		||||
                  *LiteralUtil::MakeTuple(
 | 
			
		||||
                      {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
 | 
			
		||||
                       LiteralUtil::CreateR1<float>({2.0, 42}).get()}));
 | 
			
		||||
  ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
                                {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
 | 
			
		||||
                                 LiteralUtil::CreateR1<float>({2.0, 42})}));
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> result =
 | 
			
		||||
      ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
 | 
			
		||||
  Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
 | 
			
		||||
                                       LiteralSlice(*result, {0}), error_spec_);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(*result, {1}),
 | 
			
		||||
                                       LiteralSlice(result, {0}), error_spec_);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(result, {1}),
 | 
			
		||||
                                       error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ConstantsTest, Token) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  ConstantLiteral(&builder, *LiteralUtil::CreateToken());
 | 
			
		||||
  ConstantLiteral(&builder, LiteralUtil::CreateToken());
 | 
			
		||||
  // TODO(b/80000000): tokens cannot be returned from computations.
 | 
			
		||||
  Tuple(&builder, {});
 | 
			
		||||
  TF_ASSERT_OK(Execute(&builder, {}).status());
 | 
			
		||||
 | 
			
		||||
@ -210,10 +210,10 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
 | 
			
		||||
      static_cast<int64>(0x8000008000000000LL),
 | 
			
		||||
      static_cast<int64>(0x8000010000000000LL),
 | 
			
		||||
  };
 | 
			
		||||
  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int64>({arg});
 | 
			
		||||
  auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
 | 
			
		||||
  Literal arg_literal = LiteralUtil::CreateR1<int64>({arg});
 | 
			
		||||
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
 | 
			
		||||
  std::unique_ptr<GlobalData> arg_data =
 | 
			
		||||
      client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  ConvertElementType(arg_param, F32);
 | 
			
		||||
 | 
			
		||||
@ -229,10 +229,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
 | 
			
		||||
  std::vector<uint32> arg{0,          1,          0x1000,     0x7fffffff,
 | 
			
		||||
                          0x80000000, 0x80000001, 0x80000002, 0x80000003,
 | 
			
		||||
                          0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
 | 
			
		||||
  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
 | 
			
		||||
  auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
 | 
			
		||||
  Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
 | 
			
		||||
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
 | 
			
		||||
  std::unique_ptr<GlobalData> arg_data =
 | 
			
		||||
      client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  ConvertElementType(arg_param, F32);
 | 
			
		||||
 | 
			
		||||
@ -247,10 +247,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  std::vector<float> arg{0.0f,        1.0f,          16777216.0f,
 | 
			
		||||
                         16777218.0f, 2147483647.0f, 4294967040.0f};
 | 
			
		||||
  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
 | 
			
		||||
  auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
 | 
			
		||||
  Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
 | 
			
		||||
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
 | 
			
		||||
  std::unique_ptr<GlobalData> arg_data =
 | 
			
		||||
      client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  ConvertElementType(arg_param, U32);
 | 
			
		||||
 | 
			
		||||
@ -264,10 +264,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
 | 
			
		||||
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
 | 
			
		||||
  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
 | 
			
		||||
  auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
 | 
			
		||||
  Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
 | 
			
		||||
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
 | 
			
		||||
  std::unique_ptr<GlobalData> arg_data =
 | 
			
		||||
      client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  ConvertElementType(arg_param, S64);
 | 
			
		||||
 | 
			
		||||
@ -281,10 +281,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
 | 
			
		||||
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
 | 
			
		||||
  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int32>({arg});
 | 
			
		||||
  auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
 | 
			
		||||
  Literal arg_literal = LiteralUtil::CreateR1<int32>({arg});
 | 
			
		||||
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
 | 
			
		||||
  std::unique_ptr<GlobalData> arg_data =
 | 
			
		||||
      client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  ConvertElementType(arg_param, S64);
 | 
			
		||||
 | 
			
		||||
@ -318,10 +318,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
 | 
			
		||||
                         9223370937343148032.f,
 | 
			
		||||
                         -9223371487098961920.f,
 | 
			
		||||
                         -9223370937343148032.f};
 | 
			
		||||
  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
 | 
			
		||||
  auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
 | 
			
		||||
  Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
 | 
			
		||||
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
 | 
			
		||||
  std::unique_ptr<GlobalData> arg_data =
 | 
			
		||||
      client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  ConvertElementType(arg_param, S64);
 | 
			
		||||
 | 
			
		||||
@ -456,7 +456,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
 | 
			
		||||
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      std::unique_ptr<GlobalData> dot_lhs_handle,
 | 
			
		||||
      client_->TransferToServer(*LiteralUtil::CreateR1<half>(input)));
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR1<half>(input)));
 | 
			
		||||
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  ConvertElementType(
 | 
			
		||||
@ -476,7 +476,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
 | 
			
		||||
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      std::unique_ptr<GlobalData> dot_lhs_handle,
 | 
			
		||||
      client_->TransferToServer(*LiteralUtil::CreateR1<float>(input)));
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR1<float>(input)));
 | 
			
		||||
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  ConvertElementType(
 | 
			
		||||
 | 
			
		||||
@ -93,8 +93,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest,
 | 
			
		||||
  auto weight_array = absl::make_unique<Array4D<float>>(4, 3, 1, 1);
 | 
			
		||||
  weight_array->FillWithMultiples(0.2);
 | 
			
		||||
  auto weight_data =
 | 
			
		||||
      client_
 | 
			
		||||
          ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array))
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
 | 
			
		||||
@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
 | 
			
		||||
    }));
 | 
			
		||||
 | 
			
		||||
    ComputeAndCompare(&builder,
 | 
			
		||||
                      {std::move(*LiteralUtil::CreateFromArray(input_data)),
 | 
			
		||||
                       std::move(*LiteralUtil::CreateFromArray(filter_data))},
 | 
			
		||||
                      {LiteralUtil::CreateFromArray(input_data),
 | 
			
		||||
                       LiteralUtil::CreateFromArray(filter_data)},
 | 
			
		||||
                      error_spec_);
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
 | 
			
		||||
        {7.0f, 8.0f},
 | 
			
		||||
    }));
 | 
			
		||||
    ComputeAndCompare(&builder,
 | 
			
		||||
                      {std::move(*LiteralUtil::CreateFromArray(input_data)),
 | 
			
		||||
                       std::move(*LiteralUtil::CreateFromArray(filter_data))},
 | 
			
		||||
                      {LiteralUtil::CreateFromArray(input_data),
 | 
			
		||||
                       LiteralUtil::CreateFromArray(filter_data)},
 | 
			
		||||
                      error_spec_);
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
 | 
			
		||||
    }));
 | 
			
		||||
 | 
			
		||||
    ComputeAndCompare(&builder,
 | 
			
		||||
                      {std::move(*LiteralUtil::CreateFromArray(input_data)),
 | 
			
		||||
                       std::move(*LiteralUtil::CreateFromArray(filter_data))},
 | 
			
		||||
                      {LiteralUtil::CreateFromArray(input_data),
 | 
			
		||||
                       LiteralUtil::CreateFromArray(filter_data)},
 | 
			
		||||
                      error_spec_);
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
 | 
			
		||||
        {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
 | 
			
		||||
    // clang-format on
 | 
			
		||||
    ComputeAndCompare(&builder,
 | 
			
		||||
                      {std::move(*LiteralUtil::CreateFromArray(input_data)),
 | 
			
		||||
                       std::move(*LiteralUtil::CreateFromArray(filter_data))},
 | 
			
		||||
                      {LiteralUtil::CreateFromArray(input_data),
 | 
			
		||||
                       LiteralUtil::CreateFromArray(filter_data)},
 | 
			
		||||
                      error_spec_);
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
 | 
			
		||||
  Array3D<float> expected({{{510, 610, 710, 810}}});
 | 
			
		||||
 | 
			
		||||
  auto input_literal =
 | 
			
		||||
      client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  auto filter_literal =
 | 
			
		||||
      client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR3<float>(&builder, expected,
 | 
			
		||||
@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
 | 
			
		||||
    Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
 | 
			
		||||
 | 
			
		||||
    auto input_literal =
 | 
			
		||||
        client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
 | 
			
		||||
        client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
 | 
			
		||||
            .ConsumeValueOrDie();
 | 
			
		||||
    auto filter_literal =
 | 
			
		||||
        client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
 | 
			
		||||
        client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
 | 
			
		||||
            .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    ComputeAndCompareR3<T>(&builder, expected,
 | 
			
		||||
@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
 | 
			
		||||
  Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
 | 
			
		||||
 | 
			
		||||
  auto input_literal =
 | 
			
		||||
      client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  auto filter_literal =
 | 
			
		||||
      client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR3<float>(&builder, expected,
 | 
			
		||||
@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
 | 
			
		||||
  Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
 | 
			
		||||
 | 
			
		||||
  auto input_literal =
 | 
			
		||||
      client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  auto filter_literal =
 | 
			
		||||
      client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
 | 
			
		||||
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareR3<float>(&builder, expected,
 | 
			
		||||
@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
 | 
			
		||||
        {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
 | 
			
		||||
 | 
			
		||||
    auto input_literal =
 | 
			
		||||
        client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
 | 
			
		||||
        client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
 | 
			
		||||
            .ConsumeValueOrDie();
 | 
			
		||||
    auto filter_literal =
 | 
			
		||||
        client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
 | 
			
		||||
        client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
 | 
			
		||||
            .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    ComputeAndCompareR3<T>(&builder, expected,
 | 
			
		||||
@ -435,23 +435,23 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
 | 
			
		||||
  std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
 | 
			
		||||
  iota(input_elems.begin(), input_elems.end(), 1.0f);
 | 
			
		||||
  auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
 | 
			
		||||
  auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
 | 
			
		||||
  auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
 | 
			
		||||
  iota(filter_elems.begin(), filter_elems.end(), 1.0f);
 | 
			
		||||
  auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
 | 
			
		||||
  auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
 | 
			
		||||
  auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto expected_r1 = LiteralUtil::CreateR1<float>(
 | 
			
		||||
      {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
 | 
			
		||||
       38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
 | 
			
		||||
  auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
 | 
			
		||||
  auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie();
 | 
			
		||||
  auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie();
 | 
			
		||||
  auto filter_literal =
 | 
			
		||||
      client_->TransferToServer(*filter_r5).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(filter_r5).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&builder, *expected_r5,
 | 
			
		||||
  ComputeAndCompareLiteral(&builder, expected_r5,
 | 
			
		||||
                           {input_literal.get(), filter_literal.get()},
 | 
			
		||||
                           error_spec_);
 | 
			
		||||
}
 | 
			
		||||
@ -498,23 +498,23 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
 | 
			
		||||
    std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
 | 
			
		||||
    iota_int_init_value(input_elems, 1);
 | 
			
		||||
    auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
 | 
			
		||||
    auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
 | 
			
		||||
    auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
 | 
			
		||||
    iota_int_init_value(filter_elems, 1);
 | 
			
		||||
    auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
 | 
			
		||||
    auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
 | 
			
		||||
    auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    auto expected_r1 = LiteralUtil::CreateR1<T>(
 | 
			
		||||
        {static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
 | 
			
		||||
    auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
 | 
			
		||||
    auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    auto input_literal =
 | 
			
		||||
        client_->TransferToServer(*input_r4).ConsumeValueOrDie();
 | 
			
		||||
        client_->TransferToServer(input_r4).ConsumeValueOrDie();
 | 
			
		||||
    auto filter_literal =
 | 
			
		||||
        client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
 | 
			
		||||
        client_->TransferToServer(filter_r4).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    ComputeAndCompareLiteral(&builder, *expected_r4,
 | 
			
		||||
    ComputeAndCompareLiteral(&builder, expected_r4,
 | 
			
		||||
                             {input_literal.get(), filter_literal.get()},
 | 
			
		||||
                             error_spec_);
 | 
			
		||||
  }
 | 
			
		||||
@ -558,12 +558,12 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
 | 
			
		||||
    std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
 | 
			
		||||
    iota_int_init_value(input_elems, 1);
 | 
			
		||||
    auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
 | 
			
		||||
    auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
 | 
			
		||||
    auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
 | 
			
		||||
    iota_int_init_value(filter_elems, 1);
 | 
			
		||||
    auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
 | 
			
		||||
    auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
 | 
			
		||||
    auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    auto expected_r1 = LiteralUtil::CreateR1<T>(
 | 
			
		||||
        {static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
 | 
			
		||||
@ -571,14 +571,14 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
 | 
			
		||||
         static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
 | 
			
		||||
         static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
 | 
			
		||||
         static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
 | 
			
		||||
    auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
 | 
			
		||||
    auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    auto input_literal =
 | 
			
		||||
        client_->TransferToServer(*input_r4).ConsumeValueOrDie();
 | 
			
		||||
        client_->TransferToServer(input_r4).ConsumeValueOrDie();
 | 
			
		||||
    auto filter_literal =
 | 
			
		||||
        client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
 | 
			
		||||
        client_->TransferToServer(filter_r4).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    ComputeAndCompareLiteral(&builder, *expected_r4,
 | 
			
		||||
    ComputeAndCompareLiteral(&builder, expected_r4,
 | 
			
		||||
                             {input_literal.get(), filter_literal.get()},
 | 
			
		||||
                             error_spec_);
 | 
			
		||||
  }
 | 
			
		||||
@ -624,26 +624,26 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest {
 | 
			
		||||
    std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
 | 
			
		||||
    iota_int_init_value(input_elems, 1);
 | 
			
		||||
    auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
 | 
			
		||||
    auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
 | 
			
		||||
    auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
 | 
			
		||||
    iota_int_init_value(filter_elems, 1);
 | 
			
		||||
    auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
 | 
			
		||||
    auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
 | 
			
		||||
    auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    auto expected_r1 = LiteralUtil::CreateR1<T>(
 | 
			
		||||
        {static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
 | 
			
		||||
         static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
 | 
			
		||||
         static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
 | 
			
		||||
         static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
 | 
			
		||||
    auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
 | 
			
		||||
    auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    auto input_literal =
 | 
			
		||||
        client_->TransferToServer(*input_r4).ConsumeValueOrDie();
 | 
			
		||||
        client_->TransferToServer(input_r4).ConsumeValueOrDie();
 | 
			
		||||
    auto filter_literal =
 | 
			
		||||
        client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
 | 
			
		||||
        client_->TransferToServer(filter_r4).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    ComputeAndCompareLiteral(&builder, *expected_r4,
 | 
			
		||||
    ComputeAndCompareLiteral(&builder, expected_r4,
 | 
			
		||||
                             {input_literal.get(), filter_literal.get()},
 | 
			
		||||
                             error_spec_);
 | 
			
		||||
  }
 | 
			
		||||
@ -692,8 +692,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
 | 
			
		||||
  expected_result.Fill(0);
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompare(&builder,
 | 
			
		||||
                    {std::move(*LiteralUtil::CreateFromArray(param0)),
 | 
			
		||||
                     std::move(*LiteralUtil::CreateFromArray(param1))},
 | 
			
		||||
                    {LiteralUtil::CreateFromArray(param0),
 | 
			
		||||
                     LiteralUtil::CreateFromArray(param1)},
 | 
			
		||||
                    error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -749,26 +749,25 @@ class Convolve1D1WindowTestBase
 | 
			
		||||
    std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
 | 
			
		||||
                               static_cast<T>(1.0f));
 | 
			
		||||
    auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
 | 
			
		||||
    auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
 | 
			
		||||
    auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
 | 
			
		||||
                                static_cast<T>(1.0f));
 | 
			
		||||
 | 
			
		||||
    auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
 | 
			
		||||
    auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
 | 
			
		||||
    auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    std::vector<T> expect_elems(batch * output_feature * num_windows,
 | 
			
		||||
                                static_cast<T>(window_size * input_feature));
 | 
			
		||||
    auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
 | 
			
		||||
    auto expected_r3 =
 | 
			
		||||
        expected_r1->Reshape({batch, num_windows, output_feature})
 | 
			
		||||
            .ConsumeValueOrDie();
 | 
			
		||||
    auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
 | 
			
		||||
                           .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
    auto input_literal =
 | 
			
		||||
        client_->TransferToServer(*input_r3).ConsumeValueOrDie();
 | 
			
		||||
        client_->TransferToServer(input_r3).ConsumeValueOrDie();
 | 
			
		||||
    auto filter_literal =
 | 
			
		||||
        client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
 | 
			
		||||
    ComputeAndCompareLiteral(&builder, *expected_r3,
 | 
			
		||||
        client_->TransferToServer(filter_r3).ConsumeValueOrDie();
 | 
			
		||||
    ComputeAndCompareLiteral(&builder, expected_r3,
 | 
			
		||||
                             {input_literal.get(), filter_literal.get()},
 | 
			
		||||
                             error_spec_);
 | 
			
		||||
  }
 | 
			
		||||
@ -868,8 +867,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
 | 
			
		||||
  }));
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompare(&builder,
 | 
			
		||||
                    {std::move(*LiteralUtil::CreateFromArray(input_data)),
 | 
			
		||||
                     std::move(*LiteralUtil::CreateFromArray(filter_data))},
 | 
			
		||||
                    {LiteralUtil::CreateFromArray(input_data),
 | 
			
		||||
                     LiteralUtil::CreateFromArray(filter_data)},
 | 
			
		||||
                    error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -891,9 +890,8 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
 | 
			
		||||
  Array4D<float> filter_data(1, 1, 1, 2);
 | 
			
		||||
  filter_data.FillIota(10);
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompare(&builder,
 | 
			
		||||
                    {std::move(*LiteralUtil::CreateFromArray(input_data)),
 | 
			
		||||
                     std::move(*LiteralUtil::CreateFromArray(filter_data))});
 | 
			
		||||
  ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data),
 | 
			
		||||
                               LiteralUtil::CreateFromArray(filter_data)});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
 | 
			
		||||
@ -928,8 +926,7 @@ XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
 | 
			
		||||
              /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums,
 | 
			
		||||
              /*feature_group_count=*/64);
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompare(&builder,
 | 
			
		||||
                    {std::move(*LiteralUtil::CreateFromArray(input_data))},
 | 
			
		||||
  ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)},
 | 
			
		||||
                    error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1335,23 +1335,23 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
 | 
			
		||||
 | 
			
		||||
  auto gradients_flat = LiteralUtil::CreateR1<float>({1});
 | 
			
		||||
  auto gradients_literal =
 | 
			
		||||
      gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
 | 
			
		||||
  auto gradients = ConstantLiteral(&builder, *gradients_literal);
 | 
			
		||||
      gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
 | 
			
		||||
  auto gradients = ConstantLiteral(&builder, gradients_literal);
 | 
			
		||||
 | 
			
		||||
  auto weights_flat = LiteralUtil::CreateR1<float>({1, 10, 100});
 | 
			
		||||
  auto weights_literal =
 | 
			
		||||
      weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
 | 
			
		||||
  auto weights = ConstantLiteral(&builder, *weights_literal);
 | 
			
		||||
      weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
 | 
			
		||||
  auto weights = ConstantLiteral(&builder, weights_literal);
 | 
			
		||||
 | 
			
		||||
  auto expected_flat = LiteralUtil::CreateR1<float>({10});
 | 
			
		||||
  auto expected_literal =
 | 
			
		||||
      expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
 | 
			
		||||
      expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto mirrored_weights = Rev(weights, {2, 3, 4});
 | 
			
		||||
  ConvWithGeneralPadding(gradients, mirrored_weights,
 | 
			
		||||
                         /*window_strides=*/{1, 1, 1},
 | 
			
		||||
                         /*padding=*/{{0, 0}, {0, 0}, {1, 1}});
 | 
			
		||||
  ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
 | 
			
		||||
  ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
 | 
			
		||||
@ -1359,17 +1359,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
 | 
			
		||||
 | 
			
		||||
  auto activations_flat = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
 | 
			
		||||
  auto activations_literal =
 | 
			
		||||
      activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
 | 
			
		||||
  auto activations = ConstantLiteral(&builder, *activations_literal);
 | 
			
		||||
      activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
 | 
			
		||||
  auto activations = ConstantLiteral(&builder, activations_literal);
 | 
			
		||||
 | 
			
		||||
  auto gradients_flat = LiteralUtil::CreateR1<float>({100, 10, 1});
 | 
			
		||||
  auto gradients_literal =
 | 
			
		||||
      gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
 | 
			
		||||
  auto gradients = ConstantLiteral(&builder, *gradients_literal);
 | 
			
		||||
      gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
 | 
			
		||||
  auto gradients = ConstantLiteral(&builder, gradients_literal);
 | 
			
		||||
 | 
			
		||||
  auto expected_flat = LiteralUtil::CreateR1<float>({13, 24, 130});
 | 
			
		||||
  auto expected_literal =
 | 
			
		||||
      expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
 | 
			
		||||
      expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto forward_conv =
 | 
			
		||||
      ConvGeneralDilated(activations, gradients,
 | 
			
		||||
@ -1379,7 +1379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
 | 
			
		||||
                         XlaBuilder::CreateDefaultConvDimensionNumbers(
 | 
			
		||||
                             /*num_spatial_dims=*/3));
 | 
			
		||||
  Transpose(forward_conv, {0, 1, 2, 3, 4});
 | 
			
		||||
  ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
 | 
			
		||||
  ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -40,16 +40,16 @@ class CopyOpTest : public HloTestBase {
 | 
			
		||||
 protected:
 | 
			
		||||
  void TestCopyOp(const Literal& literal) {
 | 
			
		||||
    auto builder = HloComputation::Builder(TestName());
 | 
			
		||||
    auto constant = builder.AddInstruction(
 | 
			
		||||
        HloInstruction::CreateConstant(literal.CloneToUnique()));
 | 
			
		||||
    auto constant =
 | 
			
		||||
        builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone()));
 | 
			
		||||
    builder.AddInstruction(HloInstruction::CreateUnary(
 | 
			
		||||
        constant->shape(), HloOpcode::kCopy, constant));
 | 
			
		||||
    auto computation = builder.Build();
 | 
			
		||||
    auto module = CreateNewModule();
 | 
			
		||||
    module->AddEntryComputation(std::move(computation));
 | 
			
		||||
 | 
			
		||||
    std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
    EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
 | 
			
		||||
    Literal result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
    EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
 | 
			
		||||
@ -58,31 +58,30 @@ class CopyOpTest : public HloTestBase {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(CopyOpTest, CopyR0Bool) {
 | 
			
		||||
  TestCopyOp(*LiteralUtil::CreateR0<bool>(true));
 | 
			
		||||
  TestCopyOp(LiteralUtil::CreateR0<bool>(true));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(CopyOpTest, CopyR1S0U32) {
 | 
			
		||||
  TestCopyOp(*LiteralUtil::CreateR1<uint32>({}));
 | 
			
		||||
  TestCopyOp(LiteralUtil::CreateR1<uint32>({}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(CopyOpTest, CopyR1S3U32) {
 | 
			
		||||
  TestCopyOp(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
 | 
			
		||||
  TestCopyOp(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
 | 
			
		||||
  TestCopyOp(
 | 
			
		||||
      *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
 | 
			
		||||
                              {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
 | 
			
		||||
  TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
 | 
			
		||||
                                    {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
 | 
			
		||||
  TestCopyOp(*LiteralUtil::CreateR4(
 | 
			
		||||
  TestCopyOp(LiteralUtil::CreateR4(
 | 
			
		||||
      {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
 | 
			
		||||
       {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
 | 
			
		||||
  TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
 | 
			
		||||
  TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
 | 
			
		||||
@ -90,7 +89,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
 | 
			
		||||
 | 
			
		||||
  // Copy literal to device to use as parameter.
 | 
			
		||||
  auto literal = LiteralUtil::CreateR0<float>(42.0);
 | 
			
		||||
  Shape shape = literal->shape();
 | 
			
		||||
  Shape shape = literal.shape();
 | 
			
		||||
 | 
			
		||||
  auto param0 = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateParameter(0, shape, "param0"));
 | 
			
		||||
@ -102,9 +101,8 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
 | 
			
		||||
  auto module = CreateNewModule();
 | 
			
		||||
  module->AddEntryComputation(std::move(computation));
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> result =
 | 
			
		||||
      ExecuteAndTransfer(std::move(module), {literal.get()});
 | 
			
		||||
  LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);
 | 
			
		||||
  Literal result = ExecuteAndTransfer(std::move(module), {&literal});
 | 
			
		||||
  LiteralTestUtil::ExpectR0Near<float>(42.0f, result, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
 | 
			
		||||
@ -123,19 +121,17 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
 | 
			
		||||
 | 
			
		||||
  auto module = CreateNewModule();
 | 
			
		||||
  module->AddEntryComputation(std::move(computation));
 | 
			
		||||
  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, *result,
 | 
			
		||||
  Literal result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, result,
 | 
			
		||||
                                       error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
 | 
			
		||||
  HloComputation::Builder builder(TestName());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal =
 | 
			
		||||
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
 | 
			
		||||
  Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
 | 
			
		||||
  // Reverse the minor-to-major order of the literal.
 | 
			
		||||
  Layout* literal_layout =
 | 
			
		||||
      literal->mutable_shape_do_not_use()->mutable_layout();
 | 
			
		||||
  Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout();
 | 
			
		||||
  ASSERT_EQ(2, literal_layout->minor_to_major_size());
 | 
			
		||||
  literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
 | 
			
		||||
 | 
			
		||||
@ -149,11 +145,11 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
 | 
			
		||||
 | 
			
		||||
  auto module = CreateNewModule();
 | 
			
		||||
  module->AddEntryComputation(std::move(computation));
 | 
			
		||||
  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
  Literal result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
 | 
			
		||||
  // The result of the computation has the default layout, which is the inverse
 | 
			
		||||
  // of the layout of the source literal.
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, *result,
 | 
			
		||||
  LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, result,
 | 
			
		||||
                                       error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -169,7 +165,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
 | 
			
		||||
 | 
			
		||||
  HloComputation::Builder builder(TestName());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(a);
 | 
			
		||||
  Literal literal = LiteralUtil::CreateR3FromArray3D(a);
 | 
			
		||||
 | 
			
		||||
  HloInstruction* constant = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(std::move(literal)));
 | 
			
		||||
@ -182,9 +178,9 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
 | 
			
		||||
  auto module = CreateNewModule();
 | 
			
		||||
  module->AddEntryComputation(std::move(computation));
 | 
			
		||||
  ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0}));
 | 
			
		||||
  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
  Literal result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR3EqualArray3D(a, *result);
 | 
			
		||||
  LiteralTestUtil::ExpectR3EqualArray3D(a, result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
 | 
			
		||||
@ -203,7 +199,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
 | 
			
		||||
 | 
			
		||||
  HloComputation::Builder builder(TestName());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal = LiteralUtil::CreateR4FromArray4D(a);
 | 
			
		||||
  Literal literal = LiteralUtil::CreateR4FromArray4D(a);
 | 
			
		||||
 | 
			
		||||
  HloInstruction* constant = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(std::move(literal)));
 | 
			
		||||
@ -216,9 +212,9 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
 | 
			
		||||
  auto module = CreateNewModule();
 | 
			
		||||
  module->AddEntryComputation(std::move(computation));
 | 
			
		||||
  ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation));
 | 
			
		||||
  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
  Literal result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
 | 
			
		||||
  LiteralTestUtil::ExpectR4EqualArray4D(a, *result);
 | 
			
		||||
  LiteralTestUtil::ExpectR4EqualArray4D(a, result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) {
 | 
			
		||||
@ -250,11 +246,11 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
 | 
			
		||||
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  Parameter(&builder, 0, in_shape, "input");
 | 
			
		||||
  auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie();
 | 
			
		||||
  auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
 | 
			
		||||
                    .ConsumeValueOrDie();
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
 | 
			
		||||
  EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -46,7 +46,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
 | 
			
		||||
  auto module =
 | 
			
		||||
      ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
 | 
			
		||||
  auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
 | 
			
		||||
  EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
 | 
			
		||||
  EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
 | 
			
		||||
@ -68,9 +68,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
 | 
			
		||||
      ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
 | 
			
		||||
  auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
 | 
			
		||||
  auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
 | 
			
		||||
  EXPECT_EQ(
 | 
			
		||||
      *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
 | 
			
		||||
      *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()}));
 | 
			
		||||
  EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
 | 
			
		||||
            ExecuteAndTransfer(std::move(module), {&literal0, &literal1}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// On the GPU backend, constants get special handling.  Someone might pass a
 | 
			
		||||
@ -95,8 +94,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
 | 
			
		||||
      ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
 | 
			
		||||
  auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
 | 
			
		||||
  auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
 | 
			
		||||
  EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
 | 
			
		||||
            *ExecuteAndTransfer(std::move(module), {literal0.get()}));
 | 
			
		||||
  EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
 | 
			
		||||
            ExecuteAndTransfer(std::move(module), {&literal0}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -80,8 +80,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
 | 
			
		||||
 | 
			
		||||
  module->AddEntryComputation(builder.Build());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
  LiteralTestUtil::ExpectR0Near<float>(44.0f, *result, error_spec_);
 | 
			
		||||
  Literal result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
  LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
 | 
			
		||||
@ -101,8 +101,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
 | 
			
		||||
 | 
			
		||||
  module->AddEntryComputation(builder.Build());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
  LiteralTestUtil::ExpectR0Near<float>(10.0f, *result, error_spec_);
 | 
			
		||||
  Literal result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
  LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(CustomCallTest,
 | 
			
		||||
@ -125,9 +125,9 @@ XLA_TEST_F(CustomCallTest,
 | 
			
		||||
 | 
			
		||||
  module->AddEntryComputation(b.Build());
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
  Literal result = ExecuteAndTransfer(std::move(module), {});
 | 
			
		||||
  LiteralTestUtil::ExpectR3EqualArray3D<float>(
 | 
			
		||||
      Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result);
 | 
			
		||||
      Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class CustomCallClientAPITest : public ClientLibraryTestBase {};
 | 
			
		||||
 | 
			
		||||
@ -64,11 +64,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) {
 | 
			
		||||
 | 
			
		||||
  // Try copying the elements back and comparing it
 | 
			
		||||
  auto handles = result_status.ConsumeValueOrDie();
 | 
			
		||||
  std::unique_ptr<Literal> literal;
 | 
			
		||||
  Literal literal;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
 | 
			
		||||
@ -86,19 +86,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
 | 
			
		||||
  auto handles1 = result_status1.ConsumeValueOrDie();
 | 
			
		||||
  auto handles2 = result_status2.ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal;
 | 
			
		||||
  Literal literal;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
 | 
			
		||||
 | 
			
		||||
  handles1[0].reset();
 | 
			
		||||
  handles1[1].reset();
 | 
			
		||||
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
 | 
			
		||||
@ -116,15 +116,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
 | 
			
		||||
  // the same as handle[3] and handle[1] should be the same as handle[2].
 | 
			
		||||
  auto handles = result_status.ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal;
 | 
			
		||||
  Literal literal;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
 | 
			
		||||
@ -142,19 +142,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
 | 
			
		||||
  // should not have been deallocated because of reference counting.
 | 
			
		||||
  global_data.reset();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Literal> literal;
 | 
			
		||||
  Literal literal;
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
 | 
			
		||||
 | 
			
		||||
  /// Try deallocating one of the repeated elements, then copy
 | 
			
		||||
  handles[0].reset();
 | 
			
		||||
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
 | 
			
		||||
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
 | 
			
		||||
@ -170,10 +170,9 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
 | 
			
		||||
 | 
			
		||||
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
  std::unique_ptr<Literal> param0_literal =
 | 
			
		||||
      LiteralUtil::CreateR1<float>({3.14f, -100.25f});
 | 
			
		||||
  Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
 | 
			
		||||
  std::unique_ptr<GlobalData> param0_data =
 | 
			
		||||
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 | 
			
		||||
  auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
 | 
			
		||||
  Tuple(&builder, {p});
 | 
			
		||||
  auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});
 | 
			
		||||
 | 
			
		||||
@ -68,16 +68,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
 | 
			
		||||
  XlaOp param;
 | 
			
		||||
  auto param_data = CreateParameterAndTransferLiteral(
 | 
			
		||||
      0,
 | 
			
		||||
      *LiteralUtil::MakeTuple(
 | 
			
		||||
          {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}).get(),
 | 
			
		||||
           LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
 | 
			
		||||
      LiteralUtil::MakeTupleFromSlices(
 | 
			
		||||
          {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
 | 
			
		||||
           LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}),
 | 
			
		||||
      "arg0", &builder, ¶m);
 | 
			
		||||
  auto lhs = GetTupleElement(param, 0);
 | 
			
		||||
  auto rhs = GetTupleElement(param, 1);
 | 
			
		||||
  Dot(lhs, rhs);
 | 
			
		||||
 | 
			
		||||
  ComputeAndCompareLiteral(&builder,
 | 
			
		||||
                           *LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
 | 
			
		||||
                           LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
 | 
			
		||||
                           {param_data.get()});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -196,11 +196,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {
 | 
			
		||||
 | 
			
		||||
  auto lhs_handle =
 | 
			
		||||
      this->client_
 | 
			
		||||
          ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
 | 
			
		||||
          ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
 | 
			
		||||
              {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  auto rhs_handle = this->client_
 | 
			
		||||
                        ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
 | 
			
		||||
                        ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
 | 
			
		||||
                            {{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
 | 
			
		||||
                        .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
@ -219,14 +219,14 @@ class SquareMatrixDot : public DotOperationTest {
 | 
			
		||||
  void TestImpl(bool lhs_row_major, bool rhs_row_major) {
 | 
			
		||||
    auto lhs_handle =
 | 
			
		||||
        client_
 | 
			
		||||
            ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
 | 
			
		||||
            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
 | 
			
		||||
                {{1.0f, 2.0f}, {3.0f, -4.0f}},
 | 
			
		||||
                LayoutUtil::MakeLayout(
 | 
			
		||||
                    MinorToMajorForIsRowMajor(lhs_row_major))))
 | 
			
		||||
            .ConsumeValueOrDie();
 | 
			
		||||
    auto rhs_handle =
 | 
			
		||||
        client_
 | 
			
		||||
            ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
 | 
			
		||||
            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
 | 
			
		||||
                {{1.0f, 6.0f}, {7.0f, -4.0f}},
 | 
			
		||||
                LayoutUtil::MakeLayout(
 | 
			
		||||
                    MinorToMajorForIsRowMajor(rhs_row_major))))
 | 
			
		||||
@ -286,24 +286,23 @@ void ParametricDotTest::TestImpl() {
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
 | 
			
		||||
      MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
 | 
			
		||||
  std::unique_ptr<Literal> dot_lhs_lit =
 | 
			
		||||
      LiteralUtil::CreateR2FromArray2DWithLayout(
 | 
			
		||||
          *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(
 | 
			
		||||
                             param.dot_lhs_row_major)));
 | 
			
		||||
  Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
 | 
			
		||||
      *dot_lhs_data, LayoutUtil::MakeLayout(
 | 
			
		||||
                         MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
 | 
			
		||||
  std::unique_ptr<GlobalData> dot_lhs_handle =
 | 
			
		||||
      client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
 | 
			
		||||
      MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
 | 
			
		||||
  Layout rhs_layout = LayoutUtil::MakeLayout(
 | 
			
		||||
      MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
 | 
			
		||||
  std::unique_ptr<Literal> dot_rhs_lit =
 | 
			
		||||
  Literal dot_rhs_lit =
 | 
			
		||||
      LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
 | 
			
		||||
  std::unique_ptr<GlobalData> dot_rhs_handle =
 | 
			
		||||
      client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();
 | 
			
		||||
      client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<Array2D<NativeT>> addend_data;
 | 
			
		||||
  std::unique_ptr<Literal> addend_lit;
 | 
			
		||||
  Literal addend_lit;
 | 
			
		||||
  std::unique_ptr<GlobalData> addend_handle;
 | 
			
		||||
 | 
			
		||||
  if (param.has_addend) {
 | 
			
		||||
@ -311,7 +310,7 @@ void ParametricDotTest::TestImpl() {
 | 
			
		||||
    addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
 | 
			
		||||
        *addend_data, LayoutUtil::MakeLayout(
 | 
			
		||||
                          MinorToMajorForIsRowMajor(param.addend_row_major)));
 | 
			
		||||
    addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
 | 
			
		||||
    addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  XlaBuilder builder(TestName());
 | 
			
		||||
@ -477,14 +476,14 @@ class NonsquareMatrixDot : public DotOperationTest {
 | 
			
		||||
  void TestImpl(bool lhs_row_major, bool rhs_row_major) {
 | 
			
		||||
    auto lhs_handle =
 | 
			
		||||
        client_
 | 
			
		||||
            ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
 | 
			
		||||
            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
 | 
			
		||||
                {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
 | 
			
		||||
                LayoutUtil::MakeLayout(
 | 
			
		||||
                    MinorToMajorForIsRowMajor(lhs_row_major))))
 | 
			
		||||
            .ConsumeValueOrDie();
 | 
			
		||||
    auto rhs_handle =
 | 
			
		||||
        client_
 | 
			
		||||
            ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
 | 
			
		||||
            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
 | 
			
		||||
                {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
 | 
			
		||||
                LayoutUtil::MakeLayout(
 | 
			
		||||
                    MinorToMajorForIsRowMajor(rhs_row_major))))
 | 
			
		||||
@ -511,12 +510,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
 | 
			
		||||
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
 | 
			
		||||
  auto lhs_handle =
 | 
			
		||||
      client_
 | 
			
		||||
          ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
 | 
			
		||||
          ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
 | 
			
		||||
              {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
  auto rhs_handle =
 | 
			
		||||
      client_
 | 
			
		||||
          ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
 | 
			
		||||
          ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
 | 
			
		||||
              {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
 | 
			
		||||
              LayoutUtil::MakeLayout({1, 0})))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
@ -584,7 +583,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
 | 
			
		||||
  Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
 | 
			
		||||
 | 
			
		||||
  auto x_data = this->client_
 | 
			
		||||
                    ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
 | 
			
		||||
                    ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
 | 
			
		||||
                        {{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
 | 
			
		||||
                          {{2000.0f, 200.0f}, {20.0f, 2.0f}}},
 | 
			
		||||
                         {{{3000.0f, 300.0f}, {30.0f, 3.0f}},
 | 
			
		||||
@ -592,7 +591,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
 | 
			
		||||
                    .ConsumeValueOrDie();
 | 
			
		||||
  auto y_data =
 | 
			
		||||
      this->client_
 | 
			
		||||
          ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
 | 
			
		||||
          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
 | 
			
		||||
              {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
 | 
			
		||||
               {{{11.0f, 22.0f}, {33.0f, 44.0f}},
 | 
			
		||||
                {{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
 | 
			
		||||
@ -630,13 +629,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
 | 
			
		||||
 | 
			
		||||
  auto x_data =
 | 
			
		||||
      this->client_
 | 
			
		||||
          ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
 | 
			
		||||
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
 | 
			
		||||
              {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
  auto y_data =
 | 
			
		||||
      this->client_
 | 
			
		||||
          ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
 | 
			
		||||
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
 | 
			
		||||
              {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
 | 
			
		||||
@ -668,7 +667,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
 | 
			
		||||
 | 
			
		||||
  auto x_data =
 | 
			
		||||
      this->client_
 | 
			
		||||
          ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
 | 
			
		||||
          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
 | 
			
		||||
              {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
 | 
			
		||||
               {{{9.0f, 10.0f}, {11.0f, 12.0f}},
 | 
			
		||||
                {{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
 | 
			
		||||
@ -676,7 +675,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
 | 
			
		||||
 | 
			
		||||
  auto y_data =
 | 
			
		||||
      this->client_
 | 
			
		||||
          ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
 | 
			
		||||
          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
 | 
			
		||||
              {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
 | 
			
		||||
               {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
 | 
			
		||||
          .ConsumeValueOrDie();
 | 
			
		||||
@ -708,14 +707,14 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
 | 
			
		||||
        auto lhs_handle =
 | 
			
		||||
            this->client_
 | 
			
		||||
                ->TransferToServer(
 | 
			
		||||
                    *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
 | 
			
		||||
                    LiteralUtil::CreateR2FromArray2DWithLayout<T>(
 | 
			
		||||
                        *lhs, LayoutUtil::MakeLayout(
 | 
			
		||||
                                  MinorToMajorForIsRowMajor(row_major))))
 | 
			
		||||
                .ConsumeValueOrDie();
 | 
			
		||||
        auto rhs_handle =
 | 
			
		||||
            this->client_
 | 
			
		||||
                ->TransferToServer(
 | 
			
		||||
                    *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
 | 
			
		||||
                    LiteralUtil::CreateR2FromArray2DWithLayout<T>(
 | 
			
		||||
                        *rhs, LayoutUtil::MakeLayout(
 | 
			
		||||
                                  MinorToMajorForIsRowMajor(row_major))))
 | 
			
		||||
                .ConsumeValueOrDie();
 | 
			
		||||
@ -778,15 +777,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      auto arg_0_value,
 | 
			
		||||
      this->client_->TransferToServer(
 | 
			
		||||
          *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
 | 
			
		||||
          LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      auto arg_1_value,
 | 
			
		||||
      this->client_->TransferToServer(
 | 
			
		||||
          *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
 | 
			
		||||
          LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      auto arg_2_value,
 | 
			
		||||
      this->client_->TransferToServer(
 | 
			
		||||
          *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
 | 
			
		||||
          LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
 | 
			
		||||
 | 
			
		||||
  Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
 | 
			
		||||
  this->template ComputeAndCompareR2<T>(
 | 
			
		||||
@ -827,15 +826,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      auto arg_0_value,
 | 
			
		||||
      this->client_->TransferToServer(
 | 
			
		||||
          *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
 | 
			
		||||
          LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      auto arg_1_value,
 | 
			
		||||
      this->client_->TransferToServer(
 | 
			
		||||
          *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
 | 
			
		||||
          LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      auto arg_2_value,
 | 
			
		||||
      this->client_->TransferToServer(
 | 
			
		||||
          *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
 | 
			
		||||
          LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
 | 
			
		||||
 | 
			
		||||
  Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
 | 
			
		||||
  this->template ComputeAndCompareR2<T>(
 | 
			
		||||
 | 
			
		||||
@ -124,13 +124,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
 | 
			
		||||
    // vector<bool> is special so that it cannot be a Span<bool>, which
 | 
			
		||||
    // is what the code below wants. So instead we do this.
 | 
			
		||||
    Literal input_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR1(input_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        LiteralUtil::CreateR1(input_values_int)
 | 
			
		||||
            .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
            .ValueOrDie();
 | 
			
		||||
    Literal expected_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR1(expected_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR1(expected_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
 | 
			
		||||
    XlaBuilder builder(TestName());
 | 
			
		||||
    // Initialize and transfer dynamic slice start indices parameter.
 | 
			
		||||
@ -150,13 +150,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
 | 
			
		||||
             const std::vector<int64>& slice_sizes,
 | 
			
		||||
             const Array2D<int>& expected_values_int) {
 | 
			
		||||
    Literal input_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
    Literal expected_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
 | 
			
		||||
    XlaBuilder builder(TestName());
 | 
			
		||||
    // Initialize and transfer dynamic slice start indices parameter.
 | 
			
		||||
@ -176,13 +176,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
 | 
			
		||||
             const std::vector<int64>& slice_sizes,
 | 
			
		||||
             const Array3D<int>& expected_values_int) {
 | 
			
		||||
    Literal input_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
    Literal expected_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
 | 
			
		||||
    XlaBuilder builder(TestName());
 | 
			
		||||
    // Initialize and transfer dynamic slice start indices parameter.
 | 
			
		||||
@ -359,17 +359,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
 | 
			
		||||
  void RunR0(int input_value_int, int update_value_int,
 | 
			
		||||
             const std::vector<IndexT> slice_starts, int expected_value_int) {
 | 
			
		||||
    Literal input_value =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR0(input_value_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR0(input_value_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
    Literal update_value =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR0(update_value_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR0(update_value_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
    Literal expected_value =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR0(expected_value_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR0(expected_value_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
 | 
			
		||||
    XlaBuilder builder(TestName());
 | 
			
		||||
    // Initialize and transfer dynamic slice start indices parameter.
 | 
			
		||||
@ -390,17 +390,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
 | 
			
		||||
             const std::vector<IndexT> slice_starts,
 | 
			
		||||
             absl::Span<const int> expected_values_int) {
 | 
			
		||||
    Literal input_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR1(input_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR1(input_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
    Literal update_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR1(update_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR1(update_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
    Literal expected_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR1(expected_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR1(expected_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
 | 
			
		||||
    XlaBuilder builder(TestName());
 | 
			
		||||
    // Initialize and transfer dynamic slice start indices parameter.
 | 
			
		||||
@ -421,17 +421,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
 | 
			
		||||
             const std::vector<IndexT> slice_starts,
 | 
			
		||||
             const Array2D<int>& expected_values_int) {
 | 
			
		||||
    Literal input_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
    Literal update_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR2FromArray2D(update_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
    Literal expected_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
 | 
			
		||||
    XlaBuilder builder(TestName());
 | 
			
		||||
    // Initialize and transfer dynamic slice start indices parameter.
 | 
			
		||||
@ -452,17 +452,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
 | 
			
		||||
             const std::vector<IndexT> slice_starts,
 | 
			
		||||
             const Array3D<int>& expected_values_int) {
 | 
			
		||||
    Literal input_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
    Literal update_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR3FromArray3D(update_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
    Literal expected_values =
 | 
			
		||||
        std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
 | 
			
		||||
                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                       .ValueOrDie());
 | 
			
		||||
        std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
 | 
			
		||||
                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
 | 
			
		||||
                      .ValueOrDie());
 | 
			
		||||
 | 
			
		||||
    XlaBuilder builder(TestName());
 | 
			
		||||
    // Initialize and transfer dynamic slice start indices parameter.
 | 
			
		||||
@ -529,9 +529,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
 | 
			
		||||
 | 
			
		||||
  template <typename NativeT>
 | 
			
		||||
  void DumpArray(const string& name, const Array3D<NativeT> values) {
 | 
			
		||||
    std::unique_ptr<Literal> literal =
 | 
			
		||||
        LiteralUtil::CreateR3FromArray3D<NativeT>(values);
 | 
			
		||||
    LOG(INFO) << name << ":" << literal->ToString();
 | 
			
		||||
    Literal literal = LiteralUtil::CreateR3FromArray3D<NativeT>(values);
 | 
			
		||||
    LOG(INFO) << name << ":" << literal.ToString();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@ -719,7 +718,7 @@ void BM_DynamicSlice(int num_iters) {
 | 
			
		||||
  auto input_literal = LiteralUtil::CreateR4(
 | 
			
		||||
      {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
 | 
			
		||||
        {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
 | 
			
		||||
  auto input = ConstantLiteral(&builder, *input_literal);
 | 
			
		||||
  auto input = ConstantLiteral(&builder, input_literal);
 | 
			
		||||
 | 
			
		||||
  // Create dynamic slice start indices as a parameter: shape [4]
 | 
			
		||||
  auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
 | 
			
		||||
@ -740,7 +739,7 @@ void BM_DynamicSlice(int num_iters) {
 | 
			
		||||
  auto stream =
 | 
			
		||||
      client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
 | 
			
		||||
  ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
 | 
			
		||||
      stream.get(), *start_indices_literal, buffer));
 | 
			
		||||
      stream.get(), start_indices_literal, buffer));
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<LocalExecutable> executable =
 | 
			
		||||
      client
 | 
			
		||||
 | 
			
		||||
@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) {
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      std::unique_ptr<GlobalData> input,
 | 
			
		||||
      client_->TransferToServer(
 | 
			
		||||
          *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
 | 
			
		||||
          LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
 | 
			
		||||
 | 
			
		||||
  XlaBuilder b(TestName() + ".add");
 | 
			
		||||
  Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1"));
 | 
			
		||||
 | 
			
		||||
@ -38,7 +38,7 @@ class ExhaustiveF32ElementwiseOpTest
 | 
			
		||||
 | 
			
		||||
    XlaBuilder builder(TestName());
 | 
			
		||||
 | 
			
		||||
    std::unique_ptr<Literal> input_literal =
 | 
			
		||||
    Literal input_literal =
 | 
			
		||||
        LiteralUtil::CreateFromDimensions(F32, {input_size});
 | 
			
		||||
    for (int64 i = begin; i < end; i++) {
 | 
			
		||||
      if (i >= known_incorrect_range.first &&
 | 
			
		||||
 | 
			
		||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user