Replace ProgramShape proto with a C++ class.

Rename the protobuf message ProgramShape to ProgramShapeProto and create a new ProgramShape C++ class with an interface which mirrors the protobuf generated code interface. This CL is a step toward replacing Shape proto with a C++ class. ProgramShape needs to be migrated first because ProgramShape contains Shapes.

PiperOrigin-RevId: 222435461
This commit is contained in:
Mark Heffernan 2018-11-21 11:17:40 -08:00 committed by TensorFlower Gardener
parent f6ce9fd485
commit f22eec10b6
36 changed files with 408 additions and 106 deletions

View File

@ -164,7 +164,8 @@ string RewriteWithName(const string& name, string code,
}
// Generate methods for args (inputs).
Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
Status GenArgMethods(const tf2xla::Config& config,
const xla::ProgramShapeProto& ps,
const CompileResult& compile_result, string* methods) {
size_t num_args = ps.parameters_size();
if (config.feed_size() != num_args) {
@ -204,7 +205,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
// Generate methods for results (outputs).
Status GenResultMethods(const tf2xla::Config& config,
const xla::ProgramShape& ps, string* methods) {
const xla::ProgramShapeProto& ps, string* methods) {
if (ps.result().element_type() != xla::TUPLE) {
// The XlaCompiler we use to build the xla computation always generates a
// tuple result, and we rely on this to simplify code generation.
@ -336,7 +337,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
ExtractEntryParamBufferInfos(buffer_infos);
std::vector<BufferInfo> buffer_infos_for_temps =
ExtractTempBufferInfos(buffer_infos);
const xla::ProgramShape& ps = compile_result.program_shape;
const xla::ProgramShapeProto& ps = compile_result.program_shape;
string methods_arg, methods_result;
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
@ -548,8 +549,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
// Shape of the args and results.
static const xla::ProgramShape* StaticProgramShape() {
static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
static const xla::ProgramShapeProto* StaticProgramShape() {
static const xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
return kShape;
}
@ -615,11 +616,11 @@ static string CreateUniqueIdentifier(const CodegenOpts& opts,
Status GenerateMetadata(const CodegenOpts& opts,
const CompileResult& compile_result,
MetadataResult* metadata_result) {
std::unique_ptr<xla::ProgramShape> program_shape;
std::unique_ptr<xla::ProgramShapeProto> program_shape;
if (opts.gen_program_shape) {
program_shape =
absl::make_unique<xla::ProgramShape>(compile_result.program_shape);
absl::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);
// The parameter names are currently meaningless, and redundant with the
// rest of our metadata, so clear them out to avoid confusion and save
@ -631,8 +632,8 @@ Status GenerateMetadata(const CodegenOpts& opts,
// a shim that evaluates to nullptr, which is what we want.
ProtobufToEmbed program_shape_protobuf{
CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape",
program_shape.get()};
CreateUniqueIdentifier(opts, "ProgramShapeProto"),
"xla::ProgramShapeProto", program_shape.get()};
ProtobufToEmbed hlo_profile_printer_data_protobuf{
CreateUniqueIdentifier(opts, "HloProfilePrinterData"),

View File

@ -57,7 +57,7 @@ struct MetadataResult {
std::vector<string> header_variable_decls;
// program_shape_access_shim is a C++ expression that constructs the
// xla::ProgramShape instance for the CompileResult passed to
// xla::ProgramShapeProto instance for the CompileResult passed to
// GenerateMetadata.
string program_shape_access_shim;

View File

@ -181,13 +181,15 @@ TEST(CodegenTest, Golden) {
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
5, {}));
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
{
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
},
xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}));
compile_result.program_shape =
xla::ShapeUtil::MakeProgramShape(
{
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
},
xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}))
.ToProto();
compile_result.entry_point = "entry_point";
compile_result.pointer_size = 8;

View File

@ -22,7 +22,7 @@ extern "C" void entry_point(
void* result, const xla::ExecutableRunOptions* run_options,
const void** args, void** temps, tensorflow::int64* profile_counters);
extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[];
extern "C" char __tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[];
namespace foo {
@ -253,10 +253,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
}
// Shape of the args and results.
static const xla::ProgramShape* StaticProgramShape() {
static const xla::ProgramShape* kShape = []() {
xla::ProgramShape* proto = new xla::ProgramShape;
proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52);
static const xla::ProgramShapeProto* StaticProgramShape() {
static const xla::ProgramShapeProto* kShape = []() {
xla::ProgramShapeProto* proto = new xla::ProgramShapeProto;
proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52);
return proto;
}();
return kShape;

View File

@ -56,8 +56,8 @@ Status CompileXla(xla::CompileOnlyClient* client,
return errors::Unknown("Couldn't get XLA program shape: ",
pshape_or.status().error_message());
}
compile_result->program_shape = *pshape_or.ValueOrDie();
xla::ProgramShape* pshape = &compile_result->program_shape;
compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
xla::ProgramShapeProto* pshape = &compile_result->program_shape;
std::vector<const xla::Shape*> arg_layouts;
arg_layouts.reserve(pshape->parameters_size());
for (int i = 0; i < pshape->parameters_size(); ++i) {

View File

@ -33,9 +33,9 @@ namespace tfcompile {
struct CompileResult {
// Contains object file and meta-info.
std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
xla::ProgramShape program_shape; // Static shape of args and results.
string entry_point; // Name of generated function.
int pointer_size = 0; // Size of a pointer in bytes.
xla::ProgramShapeProto program_shape; // Static shape of args and results.
string entry_point; // Name of generated function.
int pointer_size = 0; // Size of a pointer in bytes.
};
// CompileGraph compiles the graph_def into an object file containing a function

View File

@ -526,7 +526,7 @@ TEST(TFCompileTest, ProgramShape) {
// muladd has the program shape defined.
MatMulAndAddComp muladd;
const xla::ProgramShape* muladd_shape = muladd.ProgramShape();
const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape();
ASSERT_TRUE(muladd_shape != nullptr);
ASSERT_EQ(muladd_shape->parameters_size(), 2);
EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2));

View File

@ -26,7 +26,7 @@ limitations under the License.
// Forward-declare, rather than include, to reduce code size for users that
// never use this functionality.
namespace xla {
class ProgramShape;
class ProgramShapeProto;
class HloProfilePrinterData;
}
@ -84,7 +84,7 @@ class XlaCompiledCpuFunction {
void set_result_names(const char** result_names) {
result_names_ = result_names;
}
void set_program_shape(const xla::ProgramShape* program_shape) {
void set_program_shape(const xla::ProgramShapeProto* program_shape) {
program_shape_ = program_shape;
}
const xla::HloProfilePrinterData* hlo_profile_printer_data() const {
@ -122,7 +122,7 @@ class XlaCompiledCpuFunction {
const char** result_names_ = nullptr;
// [Optional] Arg and result shapes.
const xla::ProgramShape* program_shape_ = nullptr;
const xla::ProgramShapeProto* program_shape_ = nullptr;
// [Optional] Profile printer data. Null if profiling is disabled.
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
@ -264,7 +264,7 @@ class XlaCompiledCpuFunction {
// Returns the shape of the args and results. May return nullptr if the
// program shape isn't available.
const xla::ProgramShape* ProgramShape() const { return program_shape_; }
const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; }
bool hlo_profiling_enabled() const {
return hlo_profile_printer_data_ != nullptr;
@ -305,7 +305,7 @@ class XlaCompiledCpuFunction {
// Optional metadata.
const char** arg_names_ = nullptr;
const char** result_names_ = nullptr;
const xla::ProgramShape* program_shape_ = nullptr;
const xla::ProgramShapeProto* program_shape_ = nullptr;
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
};

View File

@ -133,7 +133,8 @@ XlaJitCompiledCpuFunction::Compile(
jit->executable_ = std::move(executable);
jit->buffer_infos_ = std::move(buffer_infos);
jit->arg_index_table_ = std::move(arg_index_table);
jit->program_shape_ = std::move(program_shape);
jit->program_shape_ =
absl::make_unique<xla::ProgramShapeProto>(program_shape->ToProto());
jit->static_data_.set_raw_function(raw_function);
jit->static_data_.set_buffer_infos(jit->buffer_infos_.data());
jit->static_data_.set_num_buffers(jit->buffer_infos_.size());

View File

@ -80,8 +80,10 @@ class XlaJitCompiledCpuFunction {
std::vector<const char*> arg_names_;
std::vector<const char*> result_names_;
// The backing data for the program shape.
std::unique_ptr<const xla::ProgramShape> program_shape_;
// The backing data for the program shape. The proto form of program shape is
// used because the program shape is serialized and embedded in the object
// file.
std::unique_ptr<const xla::ProgramShapeProto> program_shape_;
};
} // namespace tensorflow

View File

@ -116,7 +116,7 @@ TEST(XlaJitCompiledCpuFunction, Sum) {
// Check program shape.
using xla::ShapeUtil;
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
const xla::ProgramShape* program_shape = function.ProgramShape();
const xla::ProgramShapeProto* program_shape = function.ProgramShape();
ASSERT_TRUE(program_shape != nullptr);
ASSERT_EQ(program_shape->parameters_size(), 2);
EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32));

View File

@ -226,12 +226,14 @@ cc_library(
"index_util.cc",
"layout_util.cc",
"primitive_util.cc",
"shape.cc",
"shape_util.cc",
],
hdrs = [
"index_util.h",
"layout_util.h",
"primitive_util.h",
"shape.h",
"shape_util.h",
],
visibility = ["//visibility:public"],
@ -254,6 +256,23 @@ cc_library(
],
)
tf_cc_test(
name = "shape_test",
srcs = ["shape_test.cc"],
deps = [
":shape_util",
":status_macros",
":test",
":test_helpers",
":types",
":util",
":xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],
)
tf_cc_test(
name = "shape_util_test",
srcs = ["shape_util_test.cc"],

View File

@ -191,6 +191,7 @@ cc_library(
hdrs = ["xla_computation.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",

View File

@ -288,7 +288,8 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
HloComputationProto entry;
SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id));
TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
*entry.mutable_program_shape() = program_shape.ToProto();
entry.set_root_id(root_id);
for (auto& instruction : instructions_) {
@ -2372,7 +2373,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
GetNextId());
entry.set_root_id(root->id());
ProgramShape* program_shape = entry.mutable_program_shape();
ProgramShapeProto* program_shape = entry.mutable_program_shape();
*program_shape->mutable_result() = root->shape();
// We use std::set to keep the instruction ids in ascending order (which is

View File

@ -25,7 +25,7 @@ namespace xla {
StatusOr<ProgramShape> XlaComputation::GetProgramShape() const {
TF_RET_CHECK(proto_.has_host_program_shape());
return proto_.host_program_shape();
return ProgramShape(proto_.host_program_shape());
}
StatusOr<std::unique_ptr<HloSnapshot>> XlaComputation::Snapshot() const {

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -487,12 +487,13 @@ StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt(
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
ProgramShape shapes;
for (auto& shape : argument_shapes) {
*shapes->add_parameters() = shape;
*shapes.add_parameters() = shape;
}
TF_ASSIGN_OR_RETURN(*shapes->mutable_result(), GetReturnValueShape());
LayoutUtil::SetToDefaultLayout(shapes);
TF_ASSIGN_OR_RETURN(*shapes.mutable_result(), GetReturnValueShape());
LayoutUtil::SetToDefaultLayout(&shapes);
*config->mutable_program_shape() = shapes.ToProto();
auto snapshot = computation().Snapshot().ValueOrDie();
*c.mutable_hlo_snapshot() = *snapshot;

View File

@ -86,15 +86,15 @@ CompileOnlyService::CompileAheadOfTime(
Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot));
}
const auto& program_shape = instance.computation.host_program_shape();
ExecutionOptions execution_options;
*execution_options.mutable_debug_options() = debug_options;
*execution_options.mutable_shape_with_output_layout() =
*instance.result_layout;
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(program_shape, instance.argument_layouts,
&execution_options));
CreateModuleConfig(
ProgramShape(instance.computation.host_program_shape()),
instance.argument_layouts, &execution_options));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> hlo_module,

View File

@ -205,7 +205,8 @@ message HloComputationProto {
repeated HloInstructionProto instructions = 2;
// The program shape (with layout) of this computation.
xla.ProgramShape program_shape = 4;
xla.ProgramShapeProto program_shape = 4;
// The id of this computation.
int64 id = 5;
@ -297,7 +298,7 @@ message HloModuleProto {
repeated HloComputationProto computations = 3;
// The host program shape (with layout) of the entry computation.
xla.ProgramShape host_program_shape = 4;
xla.ProgramShapeProto host_program_shape = 4;
// The id of this module.
int64 id = 5;

View File

@ -499,7 +499,7 @@ HloComputationProto HloComputation::ToProto() const {
proto.add_instructions()->Swap(&instruction_proto);
}
proto.set_root_id(root_instruction()->unique_id());
*proto.mutable_program_shape() = ComputeProgramShape();
*proto.mutable_program_shape() = ComputeProgramShape().ToProto();
return proto;
}

View File

@ -240,7 +240,7 @@ HloModuleProto HloModule::ToProto() const {
*proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
}
*proto.mutable_host_program_shape() =
entry_computation_layout().ComputeProgramShape();
entry_computation_layout().ComputeProgramShape().ToProto();
*proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
*proto.mutable_dynamic_parameter_binding() =
dynamic_parameter_binding().ToProto();
@ -371,7 +371,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
<< "No program shape found in the proto";
const auto& program_shape = module.host_program_shape();
HloModuleConfig module_config(program_shape);
HloModuleConfig module_config(ProgramShape{program_shape});
module_config.set_debug_options(debug_options);
// The module config is constructed with default layouts regardless of what is

View File

@ -145,7 +145,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
const ExecutableBuildOptions& build_options) {
const HloModuleProto& proto = computation.proto();
TF_RET_CHECK(proto.has_host_program_shape());
const ProgramShape& program_shape = proto.host_program_shape();
ProgramShape program_shape(proto.host_program_shape());
// Validate incoming layouts.
if (argument_layouts.size() != program_shape.parameters_size()) {

View File

@ -658,9 +658,9 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
// replica 0.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(request.computation().host_program_shape(),
replicated_arguments.front(),
request.execution_options()));
CreateModuleConfig(
ProgramShape{request.computation().host_program_shape()},
replicated_arguments.front(), request.execution_options()));
VLOG(3)
<< "ExecuteGraphParallel created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@ -824,7 +824,7 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(arg->computation().host_program_shape(),
CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
argument_shapes, &arg->execution_options()));
VLOG(3) << "Compile created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@ -1072,7 +1072,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
"constant computation may not depend on any parameters.");
}
ProgramShape program_shape = arg->computation().host_program_shape();
ProgramShape program_shape(arg->computation().host_program_shape());
TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
if (arg->has_output_layout()) {
TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
@ -1116,7 +1116,7 @@ Status Service::GetComputationGraphStats(
return InvalidArgument("Program shape may not be empty.");
}
HloModuleConfig config(arg->computation().host_program_shape());
HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()});
config.set_debug_options(arg->debug_options());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
CreateModuleFromProto(arg->computation(), config));

View File

@ -0,0 +1,62 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/shape.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
for (const Shape& shape : program_shape_proto.parameters()) {
*add_parameters() = shape;
}
*mutable_result() = program_shape_proto.result();
for (const string& name : program_shape_proto.parameter_names()) {
add_parameter_names(name);
}
}
ProgramShapeProto ProgramShape::ToProto() const {
ProgramShapeProto proto;
for (const Shape& shape : parameters()) {
*proto.add_parameters() = shape;
}
*proto.mutable_result() = result();
for (const string& name : parameter_names()) {
proto.add_parameter_names(name);
}
return proto;
}
string ProgramShape::ToString() const {
std::vector<string> parameter_strings(parameters_size());
for (int i = 0; i < parameters_size(); ++i) {
parameter_strings[i] = absl::StrCat(
i < parameter_names_size() ? parameter_names(i) : "(unknown)", ": ",
ShapeUtil::HumanString(parameters(i)));
}
return absl::StrCat("(", absl::StrJoin(parameter_strings, ", "), ") -> ",
ShapeUtil::HumanString(result()));
}
std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape) {
out << program_shape.ToString() << "\n";
return out;
}
} // namespace xla

View File

@ -0,0 +1,108 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_
#define TENSORFLOW_COMPILER_XLA_SHAPE_H_
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
// Shape of the parameters and output of an XLA computation. This is analogous
// to a traditional function signature.
class ProgramShape {
public:
ProgramShape() = default;
// Creates a ProgramShape from a ProgramShapeProto protobuf.
explicit ProgramShape(const ProgramShapeProto& program_shape_proto);
// Returns a proto representation of the object.
ProgramShapeProto ToProto() const;
string ToString() const;
// The following methods mirror the protobuf generated code interface for the
// message ProgramShapeProto. This enabled easy migration of this data
// structure from a proto to a proper C++ class.
// TODO(b/29771030): Replace or augment these methods with a more ergonomic
// interface.
// Methods for accessing and manipulating the Shape of the parameters.
int parameters_size() const { return parameters_.size(); }
const Shape& parameters(int index) const { return parameters_.at(index); }
Shape* mutable_parameters(int index) { return &parameters_.at(index); }
Shape* add_parameters() {
parameters_.emplace_back();
return &parameters_.back();
}
void clear_parameters() { parameters_.clear(); }
const std::vector<Shape>& parameters() const { return parameters_; }
std::vector<Shape>* mutable_parameters() { return &parameters_; }
// Methods for accessing and manipulating the Shape of the result.
const Shape& result() const { return result_; }
Shape* mutable_result() { return &result_; }
void clear_result() { result_.Clear(); }
// Methods for accessing and manipulating the names of the parameters.
int parameter_names_size() const { return parameter_names_.size(); }
const string& parameter_names(int index) const {
return parameter_names_.at(index);
}
void set_parameter_names(int index, const string& value) {
parameter_names_.at(index) = value;
}
string* mutable_parameter_names(int index) {
return &parameter_names_.at(index);
}
void add_parameter_names(const string& value) {
parameter_names_.push_back(value);
}
string* add_parameter_names() {
parameter_names_.push_back("");
return &parameter_names_.back();
}
void clear_parameter_names() { parameter_names_.clear(); }
const std::vector<string>& parameter_names() const {
return parameter_names_;
}
std::vector<string>* mutable_parameter_names() { return &parameter_names_; }
string ShortDebugString() const { return ToProto().ShortDebugString(); }
string DebugString() const { return ToProto().DebugString(); }
private:
// The shapes of the parameters of the computation represented by this object.
std::vector<Shape> parameters_;
// The names of the parameters of the computation represented by this object.
std::vector<string> parameter_names_;
// The shape of the result of the computation represented by this object.
Shape result_;
};
std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_

View File

@ -0,0 +1,112 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/shape.h"
#include <numeric>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
TEST(ShapeTest, ProgramShapeToFromProto) {
ProgramShape program_shape;
*program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3});
*program_shape.add_parameters() = ShapeUtil::MakeTokenShape();
*program_shape.add_parameters() = ShapeUtil::MakeShape(S64, {});
*program_shape.add_parameters() = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(S32, {}),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}),
ShapeUtil::MakeShape(F32, {42, 42})});
*program_shape.mutable_result() = ShapeUtil::MakeShape(F32, {7});
program_shape.add_parameter_names("foo");
program_shape.add_parameter_names("bar");
program_shape.add_parameter_names("baz");
program_shape.add_parameter_names("qux qux");
// Create a copy of the program shape by round-tripping through a proto.
ProgramShape program_shape_copy(program_shape.ToProto());
ASSERT_EQ(program_shape.parameters_size(),
program_shape_copy.parameters_size());
for (int i = 0; i < program_shape.parameters_size(); ++i) {
EXPECT_TRUE(ShapeUtil::Equal(program_shape.parameters(i),
program_shape_copy.parameters(i)));
}
EXPECT_TRUE(
ShapeUtil::Equal(program_shape.result(), program_shape_copy.result()));
ASSERT_EQ(program_shape.parameter_names_size(),
program_shape_copy.parameter_names_size());
for (int i = 0; i < program_shape.parameter_names_size(); ++i) {
EXPECT_EQ(program_shape.parameter_names(i),
program_shape_copy.parameter_names(i));
}
}
TEST(ShapeTest, ProgramShapeToString) {
Shape opaque = ShapeUtil::MakeOpaqueShape();
Shape token = ShapeUtil::MakeTokenShape();
Shape scalar = ShapeUtil::MakeShape(F32, {});
Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});
ProgramShape prog = ShapeUtil::MakeProgramShape(
{opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
EXPECT_EQ(
"((unknown): opaque[], "
"(unknown): f32[], "
"(unknown): u32[1,2], "
"(unknown): s32[3,4], "
"(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
"(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
"-> "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog));
prog.add_parameter_names("arg0");
prog.add_parameter_names("scalar");
prog.add_parameter_names("matrix");
prog.add_parameter_names("matrix2");
prog.add_parameter_names("tuple");
prog.add_parameter_names("nested_tuple");
EXPECT_EQ(
"(arg0: opaque[], "
"scalar: f32[], "
"matrix: u32[1,2], "
"matrix2: s32[3,4], "
"tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
"nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
"token[])) "
"-> "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog));
}
} // namespace
} // namespace xla

View File

@ -563,6 +563,20 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
HumanString(program_shape.result()));
}
/* static */ string ShapeUtil::HumanString(
const ProgramShapeProto& program_shape_proto) {
std::vector<string> parameters;
for (auto& shape : program_shape_proto.parameters()) {
const int i = parameters.size();
parameters.push_back(StrCat(i < program_shape_proto.parameter_names_size()
? program_shape_proto.parameter_names(i)
: "(unknown)",
": ", HumanString(shape)));
}
return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
HumanString(program_shape_proto.result()));
}
namespace {
// Parses shapes with simple recursive descent structure -- consumes from the
// front of s and passes that view recursively as required.

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@ -239,6 +240,7 @@ class ShapeUtil {
//
// (param_name: f32[42x12], ...) -> f32[24x42]
static string HumanString(const ProgramShape& program_shape);
static string HumanString(const ProgramShapeProto& program_shape_proto);
// Parses a ShapeUtil::HumanString-format shape string back into a shape
// object.

View File

@ -575,37 +575,6 @@ TEST(ShapeUtilTest, HumanString) {
"((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
"token[])",
ShapeUtil::HumanStringWithLayout(nested_tuple));
ProgramShape prog = ShapeUtil::MakeProgramShape(
{opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
EXPECT_EQ(
"((unknown): opaque[], "
"(unknown): f32[], "
"(unknown): u32[1,2], "
"(unknown): s32[3,4], "
"(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
"(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
"-> "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog));
prog.add_parameter_names("arg0");
prog.add_parameter_names("scalar");
prog.add_parameter_names("matrix");
prog.add_parameter_names("matrix2");
prog.add_parameter_names("tuple");
prog.add_parameter_names("nested_tuple");
EXPECT_EQ(
"(arg0: opaque[], "
"scalar: f32[], "
"matrix: u32[1,2], "
"matrix2: s32[3,4], "
"tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
"nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
"token[])) "
"-> "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog));
}
TEST(ShapeUtilTest, ForEachSubshapeArray) {

View File

@ -55,7 +55,8 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) {
client_->GetComputationShape(computation).ConsumeValueOrDie();
std::unique_ptr<ProgramShape> replayed_shape =
client_->GetComputationShape(replayed).ConsumeValueOrDie();
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
replayed_shape->ToProto()));
// Run it.
Literal literal =
@ -87,7 +88,8 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
client_->GetComputationShape(computation).ConsumeValueOrDie();
std::unique_ptr<ProgramShape> replayed_shape =
client_->GetComputationShape(replayed).ConsumeValueOrDie();
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
replayed_shape->ToProto()));
// Run it.
std::unique_ptr<GlobalData> x_data =
@ -133,7 +135,8 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) {
client_->GetComputationShape(computation).ConsumeValueOrDie();
std::unique_ptr<ProgramShape> replayed_shape =
client_->GetComputationShape(replayed).ConsumeValueOrDie();
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
replayed_shape->ToProto()));
// Run it.
Literal literal =

View File

@ -183,7 +183,7 @@ message Shape {
// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShape {
message ProgramShapeProto {
repeated Shape parameters = 1;
Shape result = 2;
repeated string parameter_names = 3;

View File

@ -174,11 +174,12 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) {
ctx->set_output(0, handle_output);
xla::LocalExecutable* executable = entry->get().get_executable();
xla::ProgramShape program_shape = executable->executable()
->module()
.config()
.entry_computation_layout()
.ComputeProgramShape();
xla::ProgramShapeProto program_shape = executable->executable()
->module()
.config()
.entry_computation_layout()
.ComputeProgramShape()
.ToProto();
Tensor program_shape_output(DT_STRING, TensorShape({1}));
program_shape_output.vec<string>()(0) = program_shape.SerializeAsString();
ctx->set_output(1, program_shape_output);

View File

@ -411,7 +411,7 @@ TEST(RawApiTest, CompileAndExecute) {
auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
xla::ProgramShape program_shape;
xla::ProgramShapeProto program_shape;
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
EXPECT_EQ(program_shape.parameters_size(), 2);
}
@ -465,7 +465,7 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
xla::ProgramShape program_shape;
xla::ProgramShapeProto program_shape;
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
EXPECT_EQ(program_shape.parameters_size(), 2);
}
@ -510,7 +510,7 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(),
{c_handle.program_shape}, {release}, &outputs));
xla::ProgramShape program_shape;
xla::ProgramShapeProto program_shape;
EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec<string>()(0)));
EXPECT_EQ(program_shape.parameters_size(), 1);
@ -520,7 +520,7 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
<< xla::ShapeUtil::HumanStringWithLayout(program_shape.result());
xla::ProgramShape xla_program_shape =
XlaCompiledProgramShape(xla_computation, *shapes);
XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
EXPECT_TRUE(xla::LayoutUtil::Equal(
xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
@ -739,7 +739,7 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
auto expected = xla::LiteralUtil::CreateR0<int64>(15123899);
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
xla::ProgramShape program_shape;
xla::ProgramShapeProto program_shape;
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
EXPECT_EQ(program_shape.parameters_size(), 2);
EXPECT_TRUE(

View File

@ -36,11 +36,11 @@ message XLAComputationConfig {
tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3;
// The arg/result shapes for the whole computation.
xla.ProgramShape program_shape = 4;
xla.ProgramShapeProto program_shape = 4;
// The arg/result shapes for each core of a model-parallel
// computation. per_core_args_and_result_shapes is optional for a
// single-core computation.
repeated xla.ProgramShape per_core_program_shape = 5;
repeated xla.ProgramShapeProto per_core_program_shape = 5;
// Describes how replicated computation instances should be assigned to
// devices. There are num_cores_per_replica computations, and each one will be
// sent and executed to the set of replica device numbers described in the