Introduce new class Literal to replace protobuf Literal.

This renames the existing Literal message to LiteralProto and introduces a new
C++ class named Literal to replace it.

The LiteralProto is only used at RPC boundaries, or when protobuf-specific
functionality is required.  The Literal class offers a 'ToProto' function to
generate a new LiteralProto message when necessary.

Currently, all the static functions in class LiteralUtil, just forward to their
counterparts in class Literal.  This will change in a future CL.

Class Literal implements all the buffers as std::vectors.  The only exception
is preds(), which given the std::vector<bool> representation, makes it unusable
for the semantics we require (it's not possible to get the address of the
underlying vector, for instance).

The CL adds a BoolVector class to work around that issue.

In future CLs, the std::vector representation may be changed to something more
efficient, if needed.

PiperOrigin-RevId: 157739125
This commit is contained in:
A. Unique TensorFlower 2017-06-01 11:30:36 -07:00 committed by TensorFlower Gardener
parent 2b75a9a6ea
commit 02ac85399d
31 changed files with 1619 additions and 865 deletions

View File

@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"

View File

@ -58,14 +58,13 @@ StatusOr<std::unique_ptr<Literal>> Client::Transfer(
"server provided response without a literal in "
"TransferToClient request");
}
return WrapUnique(response.release_literal());
return MakeUnique<Literal>(response.literal());
}
StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
const Literal& literal, const DeviceHandle* device_handle) {
TransferToServerRequest request;
*request.mutable_literal() = literal;
*request.mutable_literal() = literal.ToProto();
if (device_handle) {
*request.mutable_device_handle() = *device_handle;
}
@ -93,7 +92,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
Status Client::TransferToInfeed(const Literal& literal, int64 replica_id,
const DeviceHandle* device_handle) {
TransferToInfeedRequest request;
*request.mutable_literal() = literal;
*request.mutable_literal() = literal.ToProto();
if (device_handle) {
*request.mutable_device_handle() = *device_handle;
}
@ -141,7 +140,8 @@ StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
"TransferToClient request");
}
return WrapUnique(response.release_literal());
Literal literal(response.literal());
return MakeUnique<Literal>(literal);
}
Status Client::ResetDevice() {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

View File

@ -165,9 +165,10 @@ ComputationDataHandle ComputationBuilder::ConstantOp(
}
ConstantRequest request;
Literal* literal = request.mutable_literal();
populate(literal);
VLOG(3) << "created constant: " << literal->ShortDebugString();
Literal literal;
populate(&literal);
*request.mutable_literal() = literal.ToProto();
VLOG(3) << "created constant: " << request.literal().ShortDebugString();
OpRequest op_request;
*op_request.mutable_constant_request() = request;
*op_request.mutable_computation() = computation_.handle();

View File

@ -222,8 +222,9 @@ tensorflow::Status LocalExecutable::RecordArguments(
SessionModule* session_module) {
session_module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_RETURN_IF_ERROR(
LiteralFromShapedBuffer(*argument, session_module->add_arguments()));
Literal literal;
TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal));
*session_module->add_arguments() = literal.ToProto();
}
return tensorflow::Status::OK();
}
@ -231,9 +232,13 @@ tensorflow::Status LocalExecutable::RecordArguments(
tensorflow::Status LocalExecutable::RecordResult(
const ShapedBuffer* result, SessionModule* session_module) {
session_module->clear_result();
return LiteralFromShapedBuffer(*result, session_module->mutable_result());
Literal literal(session_module->result());
TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal));
*session_module->mutable_result() = literal.ToProto();
return tensorflow::Status::OK();
}
// TODO(dnovillo) Change signature to return StatusOr<Literal>.
tensorflow::Status LocalExecutable::LiteralFromShapedBuffer(
const ShapedBuffer& shaped_buffer, Literal* literal) {
TF_ASSIGN_OR_RETURN(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -60,8 +60,8 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
int64 elements = ShapeUtil::ElementsIn(shape);
LiteralUtil::Resize(elements, std::numeric_limits<float>::quiet_NaN(),
result.get());
tensorflow::protobuf::RepeatedField<float>* field = result->mutable_f32s();
char* data = tensorflow::bit_cast<char*>(field->mutable_data());
std::vector<float>* field = result->mutable_f32s();
char* data = tensorflow::bit_cast<char*>(field->data());
uint64 bytes = elements * sizeof(float);
tensorflow::StringPiece sp;
auto s = file_->Read(offset_, bytes, &sp, data);

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -531,6 +531,7 @@ cc_library(
srcs = ["transfer_manager.cc"],
hdrs = ["transfer_manager.h"],
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"

View File

@ -46,7 +46,7 @@ message HloInstructionProto {
xla.OpMetadata metadata = 7;
// Literal, only present for kConstant.
xla.Literal literal = 8;
xla.LiteralProto literal = 8;
// Parameter info, only present for kParameter.
int64 parameter_number = 9;

View File

@ -65,7 +65,7 @@ using ::tensorflow::strings::StrCat;
WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
instruction->operands_.push_back(operand);
instruction->literal_.reset(new Literal);
*instruction->literal_->mutable_u8s() += tag;
instruction->literal_->append_u8s(tag);
return instruction;
}
@ -1551,7 +1551,7 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_metadata() = metadata_;
switch (opcode_) {
case HloOpcode::kConstant:
*proto.mutable_literal() = *literal_;
*proto.mutable_literal() = literal_->ToProto();
break;
case HloOpcode::kParameter:
proto.set_parameter_number(parameter_number_);
@ -1648,10 +1648,10 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
trace_instruction_ = trace_instruction;
}
const string& HloInstruction::tracing_tag() const {
string HloInstruction::TracingTag() const {
CHECK_EQ(HloOpcode::kTrace, opcode());
CHECK(literal_ != nullptr);
return literal_->u8s();
return literal_->u8s_string();
}
bool HloInstruction::IsFused() const {

View File

@ -30,6 +30,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@ -535,7 +536,7 @@ class HloInstruction {
// Returns a tag to be used in tracing.
//
// Precondition: opcode() == HloOpcode::kTrace
const string& tracing_tag() const;
string TracingTag() const;
// Returns whether the instruction is a constant.
bool IsConstant() const;

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "external/llvm/include/llvm/IR/Module.h"
#include "external/llvm/include/llvm/IR/Value.h"
#include "external/llvm/include/llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"

View File

@ -77,8 +77,10 @@ tensorflow::Status RecordArguments(
SessionModule* module) {
module->clear_arguments();
for (const Allocation* allocation : arg_allocations) {
TF_RETURN_IF_ERROR(LiteralFromAllocation(allocation, allocation->shape(),
module->add_arguments()));
Literal argument;
TF_RETURN_IF_ERROR(
LiteralFromAllocation(allocation, allocation->shape(), &argument));
*module->add_arguments() = argument.ToProto();
}
return tensorflow::Status::OK();
}
@ -87,8 +89,11 @@ tensorflow::Status RecordArguments(
tensorflow::Status RecordResult(const Allocation* result_allocation,
SessionModule* module) {
module->clear_result();
return LiteralFromAllocation(result_allocation, result_allocation->shape(),
module->mutable_result());
Literal result;
TF_RETURN_IF_ERROR(LiteralFromAllocation(
result_allocation, result_allocation->shape(), &result));
*module->mutable_result() = result.ToProto();
return tensorflow::Status::OK();
}
} // namespace
@ -912,13 +917,15 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg,
literal_shape = &allocation->shape();
}
return LiteralFromAllocation(allocation, *literal_shape,
result->mutable_literal());
Literal literal;
auto status = LiteralFromAllocation(allocation, *literal_shape, &literal);
*result->mutable_literal() = literal.ToProto();
return status;
}
tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
TransferToServerResponse* result) {
const Literal& literal = arg->literal();
Literal literal = Literal(arg->literal());
const Shape& shape = literal.shape();
if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
@ -982,7 +989,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
}
return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
executor, arg->literal());
executor, Literal(arg->literal()));
}
tensorflow::Status Service::TransferFromOutfeed(
@ -1005,8 +1012,12 @@ tensorflow::Status Service::TransferFromOutfeed(
executor = execute_backend_->Replicas()[arg->replica_id()];
}
return execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
executor, arg->shape_with_layout(), result->mutable_literal());
Literal literal;
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
executor, arg->shape_with_layout(), &literal));
*result->mutable_literal() = literal.ToProto();
return tensorflow::Status::OK();
}
tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,

View File

@ -75,10 +75,10 @@ message SessionModule {
repeated SessionComputation embedded_computations = 2;
// The arguments passed to the computation.
repeated Literal arguments = 3;
repeated LiteralProto arguments = 3;
// The result of the computation.
Literal result = 4;
LiteralProto result = 4;
// The name of the platform used to run the computation.
string execution_platform = 5;

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <set>
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -121,7 +121,7 @@ TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) {
const Shape shape = ShapeUtil::MakeShape(U8, {4});
TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
stream_exec_, memptr, shape, shape, &literal));
CHECK_EQ("klmn", literal.u8s());
CHECK_EQ("klmn", literal.u8s_string());
}
TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) {

View File

@ -2275,7 +2275,7 @@ void ComputationLowerer::Visit(
const ConstantRequest& constant_request =
request.request().constant_request();
hlo_instruction = add_instruction(HloInstruction::CreateConstant(
LiteralUtil::CloneToUnique(constant_request.literal())));
LiteralUtil::CloneToUnique(Literal(constant_request.literal()))));
break;
}

View File

@ -50,7 +50,7 @@ TEST_F(UserComputationTest, SimpleComputation) {
ConstantRequest constant_request;
*constant_request.mutable_literal() =
*LiteralUtil::CreateR1<float>({123.0f, 42.0f});
LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle,
computation.AddConstantInstruction(constant_request));
@ -160,12 +160,13 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) {
UserComputation computation("TheComputation", handle);
ConstantRequest a_request;
*a_request.mutable_literal() = *LiteralUtil::CreateR1<float>({123.0f, 42.0f});
*a_request.mutable_literal() =
LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
computation.AddConstantInstruction(a_request));
ConstantRequest b_request;
*b_request.mutable_literal() = *LiteralUtil::CreateR0<float>(1.0f);
*b_request.mutable_literal() = LiteralUtil::CreateR0<float>(1.0f)->ToProto();
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
computation.AddConstantInstruction(b_request));

View File

@ -179,7 +179,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal);
VLOG(1) << "actual: " << LiteralUtil::ToString(*actual);
EXPECT_EQ(expected, actual->u8s());
EXPECT_EQ(expected, actual->u8s_string());
}
void ClientLibraryTestBase::ComputeAndCompareTuple(

View File

@ -262,7 +262,7 @@ class NearComparator {
max_abs_err_ = 0.0;
*miscompares_.mutable_shape() =
ShapeUtil::ChangeElementType(actual.shape(), PRED);
miscompares_.mutable_preds()->Resize(
miscompares_.mutable_preds()->resize(
ShapeUtil::ElementsIn(miscompares_.shape()), false);
multi_index_.resize(expected.shape().dimensions_size(), 0);
@ -389,7 +389,7 @@ class NearComparator {
tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(),
now_usec, name.c_str()));
TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(),
filename, literal));
filename, literal.ToProto()));
LOG(ERROR) << "wrote to " << name << " file: " << filename;
}

View File

@ -83,9 +83,10 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]";
EXPECT_EQ(3, results.size());
for (const string& result : results) {
Literal literal;
LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
&literal));
&literal_proto));
Literal literal(literal_proto);
if (result.find("expected") != string::npos) {
EXPECT_EQ("2", LiteralUtil::ToString(literal));
} else if (result.find("actual") != string::npos) {

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"

View File

@ -66,7 +66,8 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
if (use_fake_data) {
arguments = MakeFakeArgumentsOrDie(computation, client);
} else { // use recorded data if available
for (const Literal& literal : module.arguments()) {
for (const auto& proto : module.arguments()) {
Literal literal(proto);
TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
client->TransferToServer(literal));
arguments.push_back(std::move(data));
@ -101,7 +102,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
if (module.has_result()) {
fprintf(stdout, "was %s:%s\n",
ShapeUtil::HumanString(module.result().shape()).c_str(),
LiteralUtil::ToString(module.result()).c_str());
LiteralUtil::ToString(Literal(module.result())).c_str());
}
}
}

View File

@ -37,9 +37,10 @@ int main(int argc, char **argv) {
<< " <path-to-serialized-literal-proto>";
}
xla::Literal literal;
xla::LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
&literal));
LOG(INFO) << "literal: " << literal.ShortDebugString();
&literal_proto));
xla::Literal literal(literal_proto);
LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str());
}

View File

@ -92,11 +92,11 @@ message TransferToClientRequest {
}
message TransferToClientResponse {
Literal literal = 1;
LiteralProto literal = 1;
}
message TransferToServerRequest {
Literal literal = 1;
LiteralProto literal = 1;
DeviceHandle device_handle = 2;
}
@ -105,7 +105,7 @@ message TransferToServerResponse {
}
message TransferToInfeedRequest {
Literal literal = 1;
LiteralProto literal = 1;
int64 replica_id = 2;
DeviceHandle device_handle = 3;
}
@ -123,7 +123,7 @@ message TransferFromOutfeedRequest {
}
message TransferFromOutfeedResponse {
Literal literal = 1;
LiteralProto literal = 1;
}
message ResetDeviceRequest {

View File

@ -275,7 +275,7 @@ message ChannelHandle {
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message Literal {
message LiteralProto {
Shape shape = 1;
repeated bool preds = 2;
bytes u8s = 3;
@ -285,7 +285,7 @@ message Literal {
repeated uint64 u64s = 7;
repeated float f32s = 8;
repeated double f64s = 9;
repeated Literal tuple_literals = 10;
repeated LiteralProto tuple_literals = 10;
bytes f16s = 11; // Note: the F16s are encoded in little endian byte order
}
@ -337,7 +337,7 @@ message Window {
// field in OpRequest.
message ConstantRequest {
Literal literal = 2;
LiteralProto literal = 2;
}
message GetTupleElementRequest {