Introduce new class Literal to replace protobuf Literal.
This renames the existing Literal message to LiteralProto and introduces a new C++ class named Literal to replace it. LiteralProto is now used only at RPC boundaries, or when protobuf-specific functionality is required; the Literal class offers a ToProto() function to generate a new LiteralProto message when necessary.

Currently, all the static functions in class LiteralUtil simply forward to their counterparts in class Literal. This will change in a future CL.

Class Literal stores all its buffers as std::vector. The one exception is preds(): the packed std::vector<bool> specialization is unusable for the semantics we require (it is not possible to get the address of the underlying storage, for instance), so this CL adds a BoolVector class to work around that issue. In future CLs, the std::vector representation may be changed to something more efficient, if needed.

PiperOrigin-RevId: 157739125
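To make the shape of the change concrete, here is a minimal, self-contained sketch of the idea: an in-memory Literal backed by std::vector buffers (with a BoolVector standing in for the unusable std::vector<bool>), converting to and from a proto only at the boundary. The method names follow the CL description (ToProto, mutable_f32s, mutable_preds, append_u8s, u8s_string), but the stand-in LiteralProto struct and every implementation detail below are illustrative assumptions, not the actual code in tensorflow/compiler/xla/literal_util.h.

// Sketch only: names mirror the CL description; the stand-in LiteralProto
// struct and all implementation details are illustrative, not the real class.
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

// Stand-in for the generated xla.LiteralProto message; only a few fields are
// shown and the Shape field is omitted.
struct LiteralProto {
  std::vector<bool> preds;
  std::string u8s;
  std::vector<float> f32s;
};

// std::vector<bool> is a packed specialization and cannot hand out a bool*
// to its storage, so PRED data gets a dedicated container instead.
class BoolVector {
 public:
  BoolVector() = default;
  void resize(size_t n, bool value) {
    std::unique_ptr<bool[]> grown(new bool[n]);
    for (size_t i = 0; i < n; ++i) grown[i] = i < size_ ? data_[i] : value;
    data_ = std::move(grown);
    size_ = n;
  }
  bool* data() { return data_.get(); }  // a real address, unlike vector<bool>
  const bool* data() const { return data_.get(); }
  size_t size() const { return size_; }

 private:
  std::unique_ptr<bool[]> data_;
  size_t size_ = 0;
};

// In-memory literal: plain std::vector buffers (BoolVector for PRED); the
// proto form is produced only when an RPC boundary needs it.
class Literal {
 public:
  Literal() = default;
  explicit Literal(const LiteralProto& proto)
      : u8s_(proto.u8s.begin(), proto.u8s.end()), f32s_(proto.f32s) {
    preds_.resize(proto.preds.size(), false);
    for (size_t i = 0; i < proto.preds.size(); ++i) {
      preds_.data()[i] = proto.preds[i];
    }
  }

  LiteralProto ToProto() const {
    LiteralProto proto;
    if (preds_.size() > 0) {
      proto.preds.assign(preds_.data(), preds_.data() + preds_.size());
    }
    proto.u8s.assign(u8s_.begin(), u8s_.end());
    proto.f32s = f32s_;
    return proto;
  }

  BoolVector* mutable_preds() { return &preds_; }
  std::vector<float>* mutable_f32s() { return &f32s_; }
  void append_u8s(const std::string& value) {
    u8s_.insert(u8s_.end(), value.begin(), value.end());
  }
  std::string u8s_string() const {
    return std::string(u8s_.begin(), u8s_.end());
  }

 private:
  BoolVector preds_;          // PRED buffer, addressable unlike vector<bool>
  std::vector<uint8_t> u8s_;  // one std::vector per element type in practice
  std::vector<float> f32s_;
};

Under this scheme the call-site pattern seen throughout the diff below falls out naturally: code that used to write straight into a request's mutable_literal() now fills a Literal and assigns literal.ToProto() into the request, and code that receives a proto wraps it as Literal(proto) before operating on it.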
Commit 02ac85399d (parent 2b75a9a6ea)
@@ -18,6 +18,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
 #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
 
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -58,14 +58,13 @@ StatusOr<std::unique_ptr<Literal>> Client::Transfer(
         "server provided response without a literal in "
         "TransferToClient request");
   }
-
-  return WrapUnique(response.release_literal());
+  return MakeUnique<Literal>(response.literal());
 }
 
 StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
     const Literal& literal, const DeviceHandle* device_handle) {
   TransferToServerRequest request;
-  *request.mutable_literal() = literal;
+  *request.mutable_literal() = literal.ToProto();
   if (device_handle) {
     *request.mutable_device_handle() = *device_handle;
   }
@@ -93,7 +92,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
 Status Client::TransferToInfeed(const Literal& literal, int64 replica_id,
                                 const DeviceHandle* device_handle) {
   TransferToInfeedRequest request;
-  *request.mutable_literal() = literal;
+  *request.mutable_literal() = literal.ToProto();
   if (device_handle) {
     *request.mutable_device_handle() = *device_handle;
   }
@@ -141,7 +140,8 @@ StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
         "TransferToClient request");
   }
 
-  return WrapUnique(response.release_literal());
+  Literal literal(response.literal());
+  return MakeUnique<Literal>(literal);
 }
 
 Status Client::ResetDevice() {
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/session.pb.h"
 #include "tensorflow/compiler/xla/service_interface.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -165,9 +165,10 @@ ComputationDataHandle ComputationBuilder::ConstantOp(
   }
 
   ConstantRequest request;
-  Literal* literal = request.mutable_literal();
-  populate(literal);
-  VLOG(3) << "created constant: " << literal->ShortDebugString();
+  Literal literal;
+  populate(&literal);
+  *request.mutable_literal() = literal.ToProto();
+  VLOG(3) << "created constant: " << request.literal().ShortDebugString();
   OpRequest op_request;
   *op_request.mutable_constant_request() = request;
   *op_request.mutable_computation() = computation_.handle();
@@ -222,8 +222,9 @@ tensorflow::Status LocalExecutable::RecordArguments(
     SessionModule* session_module) {
   session_module->clear_arguments();
   for (const ShapedBuffer* argument : arguments) {
-    TF_RETURN_IF_ERROR(
-        LiteralFromShapedBuffer(*argument, session_module->add_arguments()));
+    Literal literal;
+    TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal));
+    *session_module->add_arguments() = literal.ToProto();
   }
   return tensorflow::Status::OK();
 }
@@ -231,9 +232,13 @@ tensorflow::Status LocalExecutable::RecordArguments(
 tensorflow::Status LocalExecutable::RecordResult(
     const ShapedBuffer* result, SessionModule* session_module) {
   session_module->clear_result();
-  return LiteralFromShapedBuffer(*result, session_module->mutable_result());
+  Literal literal(session_module->result());
+  TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal));
+  *session_module->mutable_result() = literal.ToProto();
+  return tensorflow::Status::OK();
 }
 
+// TODO(dnovillo) Change signature to return StatusOr<Literal>.
 tensorflow::Status LocalExecutable::LiteralFromShapedBuffer(
     const ShapedBuffer& shaped_buffer, Literal* literal) {
   TF_ASSIGN_OR_RETURN(
Two file diffs suppressed because they are too large.
@@ -60,8 +60,8 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
   int64 elements = ShapeUtil::ElementsIn(shape);
   LiteralUtil::Resize(elements, std::numeric_limits<float>::quiet_NaN(),
                       result.get());
-  tensorflow::protobuf::RepeatedField<float>* field = result->mutable_f32s();
-  char* data = tensorflow::bit_cast<char*>(field->mutable_data());
+  std::vector<float>* field = result->mutable_f32s();
+  char* data = tensorflow::bit_cast<char*>(field->data());
   uint64 bytes = elements * sizeof(float);
   tensorflow::StringPiece sp;
   auto s = file_->Read(offset_, bytes, &sp, data);
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <memory>
 
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -531,6 +531,7 @@ cc_library(
     srcs = ["transfer_manager.cc"],
     hdrs = ["transfer_manager.h"],
    deps = [
+        "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <vector>
 
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
 
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -46,7 +46,7 @@ message HloInstructionProto {
   xla.OpMetadata metadata = 7;
 
   // Literal, only present for kConstant.
-  xla.Literal literal = 8;
+  xla.LiteralProto literal = 8;
 
   // Parameter info, only present for kParameter.
   int64 parameter_number = 9;
@@ -65,7 +65,7 @@ using ::tensorflow::strings::StrCat;
       WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
   instruction->operands_.push_back(operand);
   instruction->literal_.reset(new Literal);
-  *instruction->literal_->mutable_u8s() += tag;
+  instruction->literal_->append_u8s(tag);
   return instruction;
 }
 
@@ -1551,7 +1551,7 @@ HloInstructionProto HloInstruction::ToProto() const {
   *proto.mutable_metadata() = metadata_;
   switch (opcode_) {
     case HloOpcode::kConstant:
-      *proto.mutable_literal() = *literal_;
+      *proto.mutable_literal() = literal_->ToProto();
       break;
     case HloOpcode::kParameter:
       proto.set_parameter_number(parameter_number_);
@@ -1648,10 +1648,10 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
   trace_instruction_ = trace_instruction;
 }
 
-const string& HloInstruction::tracing_tag() const {
+string HloInstruction::TracingTag() const {
   CHECK_EQ(HloOpcode::kTrace, opcode());
   CHECK(literal_ != nullptr);
-  return literal_->u8s();
+  return literal_->u8s_string();
 }
 
 bool HloInstruction::IsFused() const {
@@ -30,6 +30,7 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>
 
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -535,7 +536,7 @@ class HloInstruction {
   // Returns a tag to be used in tracing.
   //
   // Precondition: opcode() == HloOpcode::kTrace
-  const string& tracing_tag() const;
+  string TracingTag() const;
 
   // Returns whether the instruction is a constant.
   bool IsConstant() const;
@@ -27,6 +27,7 @@ limitations under the License.
 #include "external/llvm/include/llvm/IR/Module.h"
 #include "external/llvm/include/llvm/IR/Value.h"
 #include "external/llvm/include/llvm/Support/raw_ostream.h"
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -77,8 +77,10 @@ tensorflow::Status RecordArguments(
     SessionModule* module) {
   module->clear_arguments();
   for (const Allocation* allocation : arg_allocations) {
-    TF_RETURN_IF_ERROR(LiteralFromAllocation(allocation, allocation->shape(),
-                                             module->add_arguments()));
+    Literal argument;
+    TF_RETURN_IF_ERROR(
+        LiteralFromAllocation(allocation, allocation->shape(), &argument));
+    *module->add_arguments() = argument.ToProto();
   }
   return tensorflow::Status::OK();
 }
@@ -87,8 +89,11 @@ tensorflow::Status RecordArguments(
 tensorflow::Status RecordResult(const Allocation* result_allocation,
                                 SessionModule* module) {
   module->clear_result();
-  return LiteralFromAllocation(result_allocation, result_allocation->shape(),
-                               module->mutable_result());
+  Literal result;
+  TF_RETURN_IF_ERROR(LiteralFromAllocation(
+      result_allocation, result_allocation->shape(), &result));
+  *module->mutable_result() = result.ToProto();
+  return tensorflow::Status::OK();
 }
 
 }  // namespace
@@ -912,13 +917,15 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg,
     literal_shape = &allocation->shape();
   }
 
-  return LiteralFromAllocation(allocation, *literal_shape,
-                               result->mutable_literal());
+  Literal literal;
+  auto status = LiteralFromAllocation(allocation, *literal_shape, &literal);
+  *result->mutable_literal() = literal.ToProto();
+  return status;
 }
 
 tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
                                              TransferToServerResponse* result) {
-  const Literal& literal = arg->literal();
+  Literal literal = Literal(arg->literal());
   const Shape& shape = literal.shape();
 
   if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
@@ -982,7 +989,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
   }
 
   return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
-      executor, arg->literal());
+      executor, Literal(arg->literal()));
 }
 
 tensorflow::Status Service::TransferFromOutfeed(
@@ -1005,8 +1012,12 @@ tensorflow::Status Service::TransferFromOutfeed(
     executor = execute_backend_->Replicas()[arg->replica_id()];
   }
 
-  return execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
-      executor, arg->shape_with_layout(), result->mutable_literal());
+  Literal literal;
+  TF_RETURN_IF_ERROR(
+      execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
+          executor, arg->shape_with_layout(), &literal));
+  *result->mutable_literal() = literal.ToProto();
+  return tensorflow::Status::OK();
 }
 
 tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
@@ -75,10 +75,10 @@ message SessionModule {
   repeated SessionComputation embedded_computations = 2;
 
   // The arguments passed to the computation.
-  repeated Literal arguments = 3;
+  repeated LiteralProto arguments = 3;
 
   // The result of the computation.
-  Literal result = 4;
+  LiteralProto result = 4;
 
   // The name of the platform used to run the computation.
   string execution_platform = 5;
@@ -20,6 +20,7 @@ limitations under the License.
 #include <set>
 #include <vector>
 
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -121,7 +121,7 @@ TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) {
   const Shape shape = ShapeUtil::MakeShape(U8, {4});
   TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
       stream_exec_, memptr, shape, shape, &literal));
-  CHECK_EQ("klmn", literal.u8s());
+  CHECK_EQ("klmn", literal.u8s_string());
 }
 
 TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) {
@@ -2275,7 +2275,7 @@ void ComputationLowerer::Visit(
       const ConstantRequest& constant_request =
          request.request().constant_request();
      hlo_instruction = add_instruction(HloInstruction::CreateConstant(
-          LiteralUtil::CloneToUnique(constant_request.literal())));
+          LiteralUtil::CloneToUnique(Literal(constant_request.literal()))));
       break;
     }
 
@@ -50,7 +50,7 @@ TEST_F(UserComputationTest, SimpleComputation) {
 
   ConstantRequest constant_request;
   *constant_request.mutable_literal() =
-      *LiteralUtil::CreateR1<float>({123.0f, 42.0f});
+      LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
   TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle,
                          computation.AddConstantInstruction(constant_request));
 
@@ -160,12 +160,13 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) {
   UserComputation computation("TheComputation", handle);
 
   ConstantRequest a_request;
-  *a_request.mutable_literal() = *LiteralUtil::CreateR1<float>({123.0f, 42.0f});
+  *a_request.mutable_literal() =
+      LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
   TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
                          computation.AddConstantInstruction(a_request));
 
   ConstantRequest b_request;
-  *b_request.mutable_literal() = *LiteralUtil::CreateR0<float>(1.0f);
+  *b_request.mutable_literal() = LiteralUtil::CreateR0<float>(1.0f)->ToProto();
   TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
                          computation.AddConstantInstruction(b_request));
 
@@ -179,7 +179,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
   VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal);
   VLOG(1) << "actual: " << LiteralUtil::ToString(*actual);
 
-  EXPECT_EQ(expected, actual->u8s());
+  EXPECT_EQ(expected, actual->u8s_string());
 }
 
 void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -262,7 +262,7 @@ class NearComparator {
     max_abs_err_ = 0.0;
     *miscompares_.mutable_shape() =
         ShapeUtil::ChangeElementType(actual.shape(), PRED);
-    miscompares_.mutable_preds()->Resize(
+    miscompares_.mutable_preds()->resize(
         ShapeUtil::ElementsIn(miscompares_.shape()), false);
     multi_index_.resize(expected.shape().dimensions_size(), 0);
 
@@ -389,7 +389,7 @@ class NearComparator {
         tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(),
                                     now_usec, name.c_str()));
     TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(),
-                                             filename, literal));
+                                             filename, literal.ToProto()));
     LOG(ERROR) << "wrote to " << name << " file: " << filename;
   }
 
@@ -83,9 +83,10 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
   LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]";
   EXPECT_EQ(3, results.size());
   for (const string& result : results) {
-    Literal literal;
+    LiteralProto literal_proto;
     TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
-                                            &literal));
+                                            &literal_proto));
+    Literal literal(literal_proto);
     if (result.find("expected") != string::npos) {
       EXPECT_EQ("2", LiteralUtil::ToString(literal));
     } else if (result.find("actual") != string::npos) {
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <memory>
 
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
 #define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
 
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -66,7 +66,8 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
   if (use_fake_data) {
     arguments = MakeFakeArgumentsOrDie(computation, client);
   } else {  // use recorded data if available
-    for (const Literal& literal : module.arguments()) {
+    for (const auto& proto : module.arguments()) {
+      Literal literal(proto);
       TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
                           client->TransferToServer(literal));
       arguments.push_back(std::move(data));
@@ -101,7 +102,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
     if (module.has_result()) {
      fprintf(stdout, "was %s:%s\n",
              ShapeUtil::HumanString(module.result().shape()).c_str(),
-              LiteralUtil::ToString(module.result()).c_str());
+              LiteralUtil::ToString(Literal(module.result())).c_str());
     }
   }
 }
@@ -37,9 +37,10 @@ int main(int argc, char **argv) {
                << " <path-to-serialized-literal-proto>";
   }
 
-  xla::Literal literal;
+  xla::LiteralProto literal_proto;
   TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
-                                          &literal));
-  LOG(INFO) << "literal: " << literal.ShortDebugString();
+                                          &literal_proto));
+  xla::Literal literal(literal_proto);
+  LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
  fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str());
 }
@@ -92,11 +92,11 @@ message TransferToClientRequest {
 }
 
 message TransferToClientResponse {
-  Literal literal = 1;
+  LiteralProto literal = 1;
 }
 
 message TransferToServerRequest {
-  Literal literal = 1;
+  LiteralProto literal = 1;
   DeviceHandle device_handle = 2;
 }
 
@@ -105,7 +105,7 @@ message TransferToServerResponse {
 }
 
 message TransferToInfeedRequest {
-  Literal literal = 1;
+  LiteralProto literal = 1;
   int64 replica_id = 2;
   DeviceHandle device_handle = 3;
 }
@@ -123,7 +123,7 @@ message TransferFromOutfeedRequest {
 }
 
 message TransferFromOutfeedResponse {
-  Literal literal = 1;
+  LiteralProto literal = 1;
 }
 
 message ResetDeviceRequest {
@@ -275,7 +275,7 @@ message ChannelHandle {
 //
 // Transfers to/from the client are encoded in literal form, and the structure
 // of the repeated fields is implied by the shape.
-message Literal {
+message LiteralProto {
   Shape shape = 1;
   repeated bool preds = 2;
   bytes u8s = 3;
@@ -285,7 +285,7 @@ message Literal {
   repeated uint64 u64s = 7;
   repeated float f32s = 8;
   repeated double f64s = 9;
-  repeated Literal tuple_literals = 10;
+  repeated LiteralProto tuple_literals = 10;
   bytes f16s = 11;  // Note: the F16s are encoded in little endian byte order
 }
 
@@ -337,7 +337,7 @@ message Window {
 // field in OpRequest.
 
 message ConstantRequest {
-  Literal literal = 2;
+  LiteralProto literal = 2;
 }
 
 message GetTupleElementRequest {