Replace ProgramShape proto with a C++ class.
Rename the protobuf message ProgramShape to ProgramShapeProto and create a new ProgramShape C++ class with an interface which mirrors the protobuf generated code interface. This CL is a step toward replacing Shape proto with a C++ class. ProgramShape needs to be migrated first because ProgramShape contains Shapes. PiperOrigin-RevId: 222435461
This commit is contained in:
parent
f6ce9fd485
commit
f22eec10b6
@ -164,7 +164,8 @@ string RewriteWithName(const string& name, string code,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate methods for args (inputs).
|
// Generate methods for args (inputs).
|
||||||
Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
|
Status GenArgMethods(const tf2xla::Config& config,
|
||||||
|
const xla::ProgramShapeProto& ps,
|
||||||
const CompileResult& compile_result, string* methods) {
|
const CompileResult& compile_result, string* methods) {
|
||||||
size_t num_args = ps.parameters_size();
|
size_t num_args = ps.parameters_size();
|
||||||
if (config.feed_size() != num_args) {
|
if (config.feed_size() != num_args) {
|
||||||
@ -204,7 +205,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
|
|||||||
|
|
||||||
// Generate methods for results (outputs).
|
// Generate methods for results (outputs).
|
||||||
Status GenResultMethods(const tf2xla::Config& config,
|
Status GenResultMethods(const tf2xla::Config& config,
|
||||||
const xla::ProgramShape& ps, string* methods) {
|
const xla::ProgramShapeProto& ps, string* methods) {
|
||||||
if (ps.result().element_type() != xla::TUPLE) {
|
if (ps.result().element_type() != xla::TUPLE) {
|
||||||
// The XlaCompiler we use to build the xla computation always generates a
|
// The XlaCompiler we use to build the xla computation always generates a
|
||||||
// tuple result, and we rely on this to simplify code generation.
|
// tuple result, and we rely on this to simplify code generation.
|
||||||
@ -336,7 +337,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
|||||||
ExtractEntryParamBufferInfos(buffer_infos);
|
ExtractEntryParamBufferInfos(buffer_infos);
|
||||||
std::vector<BufferInfo> buffer_infos_for_temps =
|
std::vector<BufferInfo> buffer_infos_for_temps =
|
||||||
ExtractTempBufferInfos(buffer_infos);
|
ExtractTempBufferInfos(buffer_infos);
|
||||||
const xla::ProgramShape& ps = compile_result.program_shape;
|
const xla::ProgramShapeProto& ps = compile_result.program_shape;
|
||||||
string methods_arg, methods_result;
|
string methods_arg, methods_result;
|
||||||
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
|
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
|
||||||
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
|
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
|
||||||
@ -548,8 +549,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
|
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
|
||||||
|
|
||||||
// Shape of the args and results.
|
// Shape of the args and results.
|
||||||
static const xla::ProgramShape* StaticProgramShape() {
|
static const xla::ProgramShapeProto* StaticProgramShape() {
|
||||||
static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
|
static const xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
|
||||||
return kShape;
|
return kShape;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -615,11 +616,11 @@ static string CreateUniqueIdentifier(const CodegenOpts& opts,
|
|||||||
Status GenerateMetadata(const CodegenOpts& opts,
|
Status GenerateMetadata(const CodegenOpts& opts,
|
||||||
const CompileResult& compile_result,
|
const CompileResult& compile_result,
|
||||||
MetadataResult* metadata_result) {
|
MetadataResult* metadata_result) {
|
||||||
std::unique_ptr<xla::ProgramShape> program_shape;
|
std::unique_ptr<xla::ProgramShapeProto> program_shape;
|
||||||
|
|
||||||
if (opts.gen_program_shape) {
|
if (opts.gen_program_shape) {
|
||||||
program_shape =
|
program_shape =
|
||||||
absl::make_unique<xla::ProgramShape>(compile_result.program_shape);
|
absl::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);
|
||||||
|
|
||||||
// The parameter names are currently meaningless, and redundant with the
|
// The parameter names are currently meaningless, and redundant with the
|
||||||
// rest of our metadata, so clear them out to avoid confusion and save
|
// rest of our metadata, so clear them out to avoid confusion and save
|
||||||
@ -631,8 +632,8 @@ Status GenerateMetadata(const CodegenOpts& opts,
|
|||||||
// a shim that evaluates to nullptr, which is what we want.
|
// a shim that evaluates to nullptr, which is what we want.
|
||||||
|
|
||||||
ProtobufToEmbed program_shape_protobuf{
|
ProtobufToEmbed program_shape_protobuf{
|
||||||
CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape",
|
CreateUniqueIdentifier(opts, "ProgramShapeProto"),
|
||||||
program_shape.get()};
|
"xla::ProgramShapeProto", program_shape.get()};
|
||||||
|
|
||||||
ProtobufToEmbed hlo_profile_printer_data_protobuf{
|
ProtobufToEmbed hlo_profile_printer_data_protobuf{
|
||||||
CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
|
CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
|
||||||
|
@ -57,7 +57,7 @@ struct MetadataResult {
|
|||||||
std::vector<string> header_variable_decls;
|
std::vector<string> header_variable_decls;
|
||||||
|
|
||||||
// program_shape_access_shim is a C++ expression that constructs the
|
// program_shape_access_shim is a C++ expression that constructs the
|
||||||
// xla::ProgramShape instance for the CompileResult passed to
|
// xla::ProgramShapeProto instance for the CompileResult passed to
|
||||||
// GenerateMetadata.
|
// GenerateMetadata.
|
||||||
string program_shape_access_shim;
|
string program_shape_access_shim;
|
||||||
|
|
||||||
|
@ -181,13 +181,15 @@ TEST(CodegenTest, Golden) {
|
|||||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
|
||||||
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
|
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
|
||||||
5, {}));
|
5, {}));
|
||||||
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
|
compile_result.program_shape =
|
||||||
{
|
xla::ShapeUtil::MakeProgramShape(
|
||||||
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
|
{
|
||||||
xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
|
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
|
||||||
},
|
xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
|
||||||
xla::ShapeUtil::MakeTupleShape(
|
},
|
||||||
{xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}));
|
xla::ShapeUtil::MakeTupleShape(
|
||||||
|
{xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}))
|
||||||
|
.ToProto();
|
||||||
compile_result.entry_point = "entry_point";
|
compile_result.entry_point = "entry_point";
|
||||||
compile_result.pointer_size = 8;
|
compile_result.pointer_size = 8;
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ extern "C" void entry_point(
|
|||||||
void* result, const xla::ExecutableRunOptions* run_options,
|
void* result, const xla::ExecutableRunOptions* run_options,
|
||||||
const void** args, void** temps, tensorflow::int64* profile_counters);
|
const void** args, void** temps, tensorflow::int64* profile_counters);
|
||||||
|
|
||||||
extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[];
|
extern "C" char __tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[];
|
||||||
|
|
||||||
|
|
||||||
namespace foo {
|
namespace foo {
|
||||||
@ -253,10 +253,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Shape of the args and results.
|
// Shape of the args and results.
|
||||||
static const xla::ProgramShape* StaticProgramShape() {
|
static const xla::ProgramShapeProto* StaticProgramShape() {
|
||||||
static const xla::ProgramShape* kShape = []() {
|
static const xla::ProgramShapeProto* kShape = []() {
|
||||||
xla::ProgramShape* proto = new xla::ProgramShape;
|
xla::ProgramShapeProto* proto = new xla::ProgramShapeProto;
|
||||||
proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52);
|
proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52);
|
||||||
return proto;
|
return proto;
|
||||||
}();
|
}();
|
||||||
return kShape;
|
return kShape;
|
||||||
|
Binary file not shown.
@ -56,8 +56,8 @@ Status CompileXla(xla::CompileOnlyClient* client,
|
|||||||
return errors::Unknown("Couldn't get XLA program shape: ",
|
return errors::Unknown("Couldn't get XLA program shape: ",
|
||||||
pshape_or.status().error_message());
|
pshape_or.status().error_message());
|
||||||
}
|
}
|
||||||
compile_result->program_shape = *pshape_or.ValueOrDie();
|
compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
|
||||||
xla::ProgramShape* pshape = &compile_result->program_shape;
|
xla::ProgramShapeProto* pshape = &compile_result->program_shape;
|
||||||
std::vector<const xla::Shape*> arg_layouts;
|
std::vector<const xla::Shape*> arg_layouts;
|
||||||
arg_layouts.reserve(pshape->parameters_size());
|
arg_layouts.reserve(pshape->parameters_size());
|
||||||
for (int i = 0; i < pshape->parameters_size(); ++i) {
|
for (int i = 0; i < pshape->parameters_size(); ++i) {
|
||||||
|
@ -33,9 +33,9 @@ namespace tfcompile {
|
|||||||
struct CompileResult {
|
struct CompileResult {
|
||||||
// Contains object file and meta-info.
|
// Contains object file and meta-info.
|
||||||
std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
|
std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
|
||||||
xla::ProgramShape program_shape; // Static shape of args and results.
|
xla::ProgramShapeProto program_shape; // Static shape of args and results.
|
||||||
string entry_point; // Name of generated function.
|
string entry_point; // Name of generated function.
|
||||||
int pointer_size = 0; // Size of a pointer in bytes.
|
int pointer_size = 0; // Size of a pointer in bytes.
|
||||||
};
|
};
|
||||||
|
|
||||||
// CompileGraph compiles the graph_def into an object file containing a function
|
// CompileGraph compiles the graph_def into an object file containing a function
|
||||||
|
@ -526,7 +526,7 @@ TEST(TFCompileTest, ProgramShape) {
|
|||||||
|
|
||||||
// muladd has the program shape defined.
|
// muladd has the program shape defined.
|
||||||
MatMulAndAddComp muladd;
|
MatMulAndAddComp muladd;
|
||||||
const xla::ProgramShape* muladd_shape = muladd.ProgramShape();
|
const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape();
|
||||||
ASSERT_TRUE(muladd_shape != nullptr);
|
ASSERT_TRUE(muladd_shape != nullptr);
|
||||||
ASSERT_EQ(muladd_shape->parameters_size(), 2);
|
ASSERT_EQ(muladd_shape->parameters_size(), 2);
|
||||||
EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2));
|
EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2));
|
||||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
|||||||
// Forward-declare, rather than include, to reduce code size for users that
|
// Forward-declare, rather than include, to reduce code size for users that
|
||||||
// never use this functionality.
|
// never use this functionality.
|
||||||
namespace xla {
|
namespace xla {
|
||||||
class ProgramShape;
|
class ProgramShapeProto;
|
||||||
class HloProfilePrinterData;
|
class HloProfilePrinterData;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ class XlaCompiledCpuFunction {
|
|||||||
void set_result_names(const char** result_names) {
|
void set_result_names(const char** result_names) {
|
||||||
result_names_ = result_names;
|
result_names_ = result_names;
|
||||||
}
|
}
|
||||||
void set_program_shape(const xla::ProgramShape* program_shape) {
|
void set_program_shape(const xla::ProgramShapeProto* program_shape) {
|
||||||
program_shape_ = program_shape;
|
program_shape_ = program_shape;
|
||||||
}
|
}
|
||||||
const xla::HloProfilePrinterData* hlo_profile_printer_data() const {
|
const xla::HloProfilePrinterData* hlo_profile_printer_data() const {
|
||||||
@ -122,7 +122,7 @@ class XlaCompiledCpuFunction {
|
|||||||
const char** result_names_ = nullptr;
|
const char** result_names_ = nullptr;
|
||||||
|
|
||||||
// [Optional] Arg and result shapes.
|
// [Optional] Arg and result shapes.
|
||||||
const xla::ProgramShape* program_shape_ = nullptr;
|
const xla::ProgramShapeProto* program_shape_ = nullptr;
|
||||||
|
|
||||||
// [Optional] Profile printer data. Null if profiling is disabled.
|
// [Optional] Profile printer data. Null if profiling is disabled.
|
||||||
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
|
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
|
||||||
@ -264,7 +264,7 @@ class XlaCompiledCpuFunction {
|
|||||||
|
|
||||||
// Returns the shape of the args and results. May return nullptr if the
|
// Returns the shape of the args and results. May return nullptr if the
|
||||||
// program shape isn't available.
|
// program shape isn't available.
|
||||||
const xla::ProgramShape* ProgramShape() const { return program_shape_; }
|
const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; }
|
||||||
|
|
||||||
bool hlo_profiling_enabled() const {
|
bool hlo_profiling_enabled() const {
|
||||||
return hlo_profile_printer_data_ != nullptr;
|
return hlo_profile_printer_data_ != nullptr;
|
||||||
@ -305,7 +305,7 @@ class XlaCompiledCpuFunction {
|
|||||||
// Optional metadata.
|
// Optional metadata.
|
||||||
const char** arg_names_ = nullptr;
|
const char** arg_names_ = nullptr;
|
||||||
const char** result_names_ = nullptr;
|
const char** result_names_ = nullptr;
|
||||||
const xla::ProgramShape* program_shape_ = nullptr;
|
const xla::ProgramShapeProto* program_shape_ = nullptr;
|
||||||
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
|
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -133,7 +133,8 @@ XlaJitCompiledCpuFunction::Compile(
|
|||||||
jit->executable_ = std::move(executable);
|
jit->executable_ = std::move(executable);
|
||||||
jit->buffer_infos_ = std::move(buffer_infos);
|
jit->buffer_infos_ = std::move(buffer_infos);
|
||||||
jit->arg_index_table_ = std::move(arg_index_table);
|
jit->arg_index_table_ = std::move(arg_index_table);
|
||||||
jit->program_shape_ = std::move(program_shape);
|
jit->program_shape_ =
|
||||||
|
absl::make_unique<xla::ProgramShapeProto>(program_shape->ToProto());
|
||||||
jit->static_data_.set_raw_function(raw_function);
|
jit->static_data_.set_raw_function(raw_function);
|
||||||
jit->static_data_.set_buffer_infos(jit->buffer_infos_.data());
|
jit->static_data_.set_buffer_infos(jit->buffer_infos_.data());
|
||||||
jit->static_data_.set_num_buffers(jit->buffer_infos_.size());
|
jit->static_data_.set_num_buffers(jit->buffer_infos_.size());
|
||||||
|
@ -80,8 +80,10 @@ class XlaJitCompiledCpuFunction {
|
|||||||
std::vector<const char*> arg_names_;
|
std::vector<const char*> arg_names_;
|
||||||
std::vector<const char*> result_names_;
|
std::vector<const char*> result_names_;
|
||||||
|
|
||||||
// The backing data for the program shape.
|
// The backing data for the program shape. The proto form of program shape is
|
||||||
std::unique_ptr<const xla::ProgramShape> program_shape_;
|
// used because the program shape is serialized and embedded in the object
|
||||||
|
// file.
|
||||||
|
std::unique_ptr<const xla::ProgramShapeProto> program_shape_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -116,7 +116,7 @@ TEST(XlaJitCompiledCpuFunction, Sum) {
|
|||||||
// Check program shape.
|
// Check program shape.
|
||||||
using xla::ShapeUtil;
|
using xla::ShapeUtil;
|
||||||
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
|
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
|
||||||
const xla::ProgramShape* program_shape = function.ProgramShape();
|
const xla::ProgramShapeProto* program_shape = function.ProgramShape();
|
||||||
ASSERT_TRUE(program_shape != nullptr);
|
ASSERT_TRUE(program_shape != nullptr);
|
||||||
ASSERT_EQ(program_shape->parameters_size(), 2);
|
ASSERT_EQ(program_shape->parameters_size(), 2);
|
||||||
EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32));
|
EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32));
|
||||||
|
@ -226,12 +226,14 @@ cc_library(
|
|||||||
"index_util.cc",
|
"index_util.cc",
|
||||||
"layout_util.cc",
|
"layout_util.cc",
|
||||||
"primitive_util.cc",
|
"primitive_util.cc",
|
||||||
|
"shape.cc",
|
||||||
"shape_util.cc",
|
"shape_util.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"index_util.h",
|
"index_util.h",
|
||||||
"layout_util.h",
|
"layout_util.h",
|
||||||
"primitive_util.h",
|
"primitive_util.h",
|
||||||
|
"shape.h",
|
||||||
"shape_util.h",
|
"shape_util.h",
|
||||||
],
|
],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
@ -254,6 +256,23 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "shape_test",
|
||||||
|
srcs = ["shape_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":shape_util",
|
||||||
|
":status_macros",
|
||||||
|
":test",
|
||||||
|
":test_helpers",
|
||||||
|
":types",
|
||||||
|
":util",
|
||||||
|
":xla_data_proto",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "shape_util_test",
|
name = "shape_util_test",
|
||||||
srcs = ["shape_util_test.cc"],
|
srcs = ["shape_util_test.cc"],
|
||||||
|
@ -191,6 +191,7 @@ cc_library(
|
|||||||
hdrs = ["xla_computation.h"],
|
hdrs = ["xla_computation.h"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
|
@ -288,7 +288,8 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
|
|||||||
|
|
||||||
HloComputationProto entry;
|
HloComputationProto entry;
|
||||||
SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
|
SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
|
||||||
TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id));
|
TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
|
||||||
|
*entry.mutable_program_shape() = program_shape.ToProto();
|
||||||
entry.set_root_id(root_id);
|
entry.set_root_id(root_id);
|
||||||
|
|
||||||
for (auto& instruction : instructions_) {
|
for (auto& instruction : instructions_) {
|
||||||
@ -2372,7 +2373,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
|||||||
SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
|
SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
|
||||||
GetNextId());
|
GetNextId());
|
||||||
entry.set_root_id(root->id());
|
entry.set_root_id(root->id());
|
||||||
ProgramShape* program_shape = entry.mutable_program_shape();
|
ProgramShapeProto* program_shape = entry.mutable_program_shape();
|
||||||
*program_shape->mutable_result() = root->shape();
|
*program_shape->mutable_result() = root->shape();
|
||||||
|
|
||||||
// We use std::set to keep the instruction ids in ascending order (which is
|
// We use std::set to keep the instruction ids in ascending order (which is
|
||||||
|
@ -25,7 +25,7 @@ namespace xla {
|
|||||||
|
|
||||||
StatusOr<ProgramShape> XlaComputation::GetProgramShape() const {
|
StatusOr<ProgramShape> XlaComputation::GetProgramShape() const {
|
||||||
TF_RET_CHECK(proto_.has_host_program_shape());
|
TF_RET_CHECK(proto_.has_host_program_shape());
|
||||||
return proto_.host_program_shape();
|
return ProgramShape(proto_.host_program_shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<HloSnapshot>> XlaComputation::Snapshot() const {
|
StatusOr<std::unique_ptr<HloSnapshot>> XlaComputation::Snapshot() const {
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
#include "tensorflow/compiler/xla/status.h"
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
@ -487,12 +487,13 @@ StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt(
|
|||||||
|
|
||||||
xrt::XLAComputation c;
|
xrt::XLAComputation c;
|
||||||
auto config = c.mutable_config();
|
auto config = c.mutable_config();
|
||||||
auto shapes = config->mutable_program_shape();
|
ProgramShape shapes;
|
||||||
for (auto& shape : argument_shapes) {
|
for (auto& shape : argument_shapes) {
|
||||||
*shapes->add_parameters() = shape;
|
*shapes.add_parameters() = shape;
|
||||||
}
|
}
|
||||||
TF_ASSIGN_OR_RETURN(*shapes->mutable_result(), GetReturnValueShape());
|
TF_ASSIGN_OR_RETURN(*shapes.mutable_result(), GetReturnValueShape());
|
||||||
LayoutUtil::SetToDefaultLayout(shapes);
|
LayoutUtil::SetToDefaultLayout(&shapes);
|
||||||
|
*config->mutable_program_shape() = shapes.ToProto();
|
||||||
auto snapshot = computation().Snapshot().ValueOrDie();
|
auto snapshot = computation().Snapshot().ValueOrDie();
|
||||||
*c.mutable_hlo_snapshot() = *snapshot;
|
*c.mutable_hlo_snapshot() = *snapshot;
|
||||||
|
|
||||||
|
@ -86,15 +86,15 @@ CompileOnlyService::CompileAheadOfTime(
|
|||||||
Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot));
|
Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot));
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& program_shape = instance.computation.host_program_shape();
|
|
||||||
ExecutionOptions execution_options;
|
ExecutionOptions execution_options;
|
||||||
*execution_options.mutable_debug_options() = debug_options;
|
*execution_options.mutable_debug_options() = debug_options;
|
||||||
*execution_options.mutable_shape_with_output_layout() =
|
*execution_options.mutable_shape_with_output_layout() =
|
||||||
*instance.result_layout;
|
*instance.result_layout;
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<HloModuleConfig> module_config,
|
std::unique_ptr<HloModuleConfig> module_config,
|
||||||
CreateModuleConfig(program_shape, instance.argument_layouts,
|
CreateModuleConfig(
|
||||||
&execution_options));
|
ProgramShape(instance.computation.host_program_shape()),
|
||||||
|
instance.argument_layouts, &execution_options));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<HloModule> hlo_module,
|
std::unique_ptr<HloModule> hlo_module,
|
||||||
|
@ -205,7 +205,8 @@ message HloComputationProto {
|
|||||||
repeated HloInstructionProto instructions = 2;
|
repeated HloInstructionProto instructions = 2;
|
||||||
|
|
||||||
// The program shape (with layout) of this computation.
|
// The program shape (with layout) of this computation.
|
||||||
xla.ProgramShape program_shape = 4;
|
|
||||||
|
xla.ProgramShapeProto program_shape = 4;
|
||||||
|
|
||||||
// The id of this computation.
|
// The id of this computation.
|
||||||
int64 id = 5;
|
int64 id = 5;
|
||||||
@ -297,7 +298,7 @@ message HloModuleProto {
|
|||||||
repeated HloComputationProto computations = 3;
|
repeated HloComputationProto computations = 3;
|
||||||
|
|
||||||
// The host program shape (with layout) of the entry computation.
|
// The host program shape (with layout) of the entry computation.
|
||||||
xla.ProgramShape host_program_shape = 4;
|
xla.ProgramShapeProto host_program_shape = 4;
|
||||||
|
|
||||||
// The id of this module.
|
// The id of this module.
|
||||||
int64 id = 5;
|
int64 id = 5;
|
||||||
|
@ -499,7 +499,7 @@ HloComputationProto HloComputation::ToProto() const {
|
|||||||
proto.add_instructions()->Swap(&instruction_proto);
|
proto.add_instructions()->Swap(&instruction_proto);
|
||||||
}
|
}
|
||||||
proto.set_root_id(root_instruction()->unique_id());
|
proto.set_root_id(root_instruction()->unique_id());
|
||||||
*proto.mutable_program_shape() = ComputeProgramShape();
|
*proto.mutable_program_shape() = ComputeProgramShape().ToProto();
|
||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -240,7 +240,7 @@ HloModuleProto HloModule::ToProto() const {
|
|||||||
*proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
|
*proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
|
||||||
}
|
}
|
||||||
*proto.mutable_host_program_shape() =
|
*proto.mutable_host_program_shape() =
|
||||||
entry_computation_layout().ComputeProgramShape();
|
entry_computation_layout().ComputeProgramShape().ToProto();
|
||||||
*proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
|
*proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
|
||||||
*proto.mutable_dynamic_parameter_binding() =
|
*proto.mutable_dynamic_parameter_binding() =
|
||||||
dynamic_parameter_binding().ToProto();
|
dynamic_parameter_binding().ToProto();
|
||||||
@ -371,7 +371,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
|
|||||||
<< "No program shape found in the proto";
|
<< "No program shape found in the proto";
|
||||||
const auto& program_shape = module.host_program_shape();
|
const auto& program_shape = module.host_program_shape();
|
||||||
|
|
||||||
HloModuleConfig module_config(program_shape);
|
HloModuleConfig module_config(ProgramShape{program_shape});
|
||||||
module_config.set_debug_options(debug_options);
|
module_config.set_debug_options(debug_options);
|
||||||
|
|
||||||
// The module config is constructed with default layouts regardless of what is
|
// The module config is constructed with default layouts regardless of what is
|
||||||
|
@ -145,7 +145,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
|
|||||||
const ExecutableBuildOptions& build_options) {
|
const ExecutableBuildOptions& build_options) {
|
||||||
const HloModuleProto& proto = computation.proto();
|
const HloModuleProto& proto = computation.proto();
|
||||||
TF_RET_CHECK(proto.has_host_program_shape());
|
TF_RET_CHECK(proto.has_host_program_shape());
|
||||||
const ProgramShape& program_shape = proto.host_program_shape();
|
ProgramShape program_shape(proto.host_program_shape());
|
||||||
|
|
||||||
// Validate incoming layouts.
|
// Validate incoming layouts.
|
||||||
if (argument_layouts.size() != program_shape.parameters_size()) {
|
if (argument_layouts.size() != program_shape.parameters_size()) {
|
||||||
|
@ -658,9 +658,9 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
|
|||||||
// replica 0.
|
// replica 0.
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<HloModuleConfig> module_config,
|
std::unique_ptr<HloModuleConfig> module_config,
|
||||||
CreateModuleConfig(request.computation().host_program_shape(),
|
CreateModuleConfig(
|
||||||
replicated_arguments.front(),
|
ProgramShape{request.computation().host_program_shape()},
|
||||||
request.execution_options()));
|
replicated_arguments.front(), request.execution_options()));
|
||||||
VLOG(3)
|
VLOG(3)
|
||||||
<< "ExecuteGraphParallel created HloModuleConfig computation layout: "
|
<< "ExecuteGraphParallel created HloModuleConfig computation layout: "
|
||||||
<< module_config->entry_computation_layout().ToString();
|
<< module_config->entry_computation_layout().ToString();
|
||||||
@ -824,7 +824,7 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
|
|||||||
[](const Shape& shape) { return &shape; });
|
[](const Shape& shape) { return &shape; });
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<HloModuleConfig> module_config,
|
std::unique_ptr<HloModuleConfig> module_config,
|
||||||
CreateModuleConfig(arg->computation().host_program_shape(),
|
CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
|
||||||
argument_shapes, &arg->execution_options()));
|
argument_shapes, &arg->execution_options()));
|
||||||
VLOG(3) << "Compile created HloModuleConfig computation layout: "
|
VLOG(3) << "Compile created HloModuleConfig computation layout: "
|
||||||
<< module_config->entry_computation_layout().ToString();
|
<< module_config->entry_computation_layout().ToString();
|
||||||
@ -1072,7 +1072,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
|
|||||||
"constant computation may not depend on any parameters.");
|
"constant computation may not depend on any parameters.");
|
||||||
}
|
}
|
||||||
|
|
||||||
ProgramShape program_shape = arg->computation().host_program_shape();
|
ProgramShape program_shape(arg->computation().host_program_shape());
|
||||||
TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
|
TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
|
||||||
if (arg->has_output_layout()) {
|
if (arg->has_output_layout()) {
|
||||||
TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
|
TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
|
||||||
@ -1116,7 +1116,7 @@ Status Service::GetComputationGraphStats(
|
|||||||
return InvalidArgument("Program shape may not be empty.");
|
return InvalidArgument("Program shape may not be empty.");
|
||||||
}
|
}
|
||||||
|
|
||||||
HloModuleConfig config(arg->computation().host_program_shape());
|
HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()});
|
||||||
config.set_debug_options(arg->debug_options());
|
config.set_debug_options(arg->debug_options());
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||||
CreateModuleFromProto(arg->computation(), config));
|
CreateModuleFromProto(arg->computation(), config));
|
||||||
|
62
tensorflow/compiler/xla/shape.cc
Normal file
62
tensorflow/compiler/xla/shape.cc
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/strings/str_join.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
|
||||||
|
for (const Shape& shape : program_shape_proto.parameters()) {
|
||||||
|
*add_parameters() = shape;
|
||||||
|
}
|
||||||
|
*mutable_result() = program_shape_proto.result();
|
||||||
|
for (const string& name : program_shape_proto.parameter_names()) {
|
||||||
|
add_parameter_names(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ProgramShapeProto ProgramShape::ToProto() const {
|
||||||
|
ProgramShapeProto proto;
|
||||||
|
for (const Shape& shape : parameters()) {
|
||||||
|
*proto.add_parameters() = shape;
|
||||||
|
}
|
||||||
|
*proto.mutable_result() = result();
|
||||||
|
for (const string& name : parameter_names()) {
|
||||||
|
proto.add_parameter_names(name);
|
||||||
|
}
|
||||||
|
return proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
string ProgramShape::ToString() const {
|
||||||
|
std::vector<string> parameter_strings(parameters_size());
|
||||||
|
for (int i = 0; i < parameters_size(); ++i) {
|
||||||
|
parameter_strings[i] = absl::StrCat(
|
||||||
|
i < parameter_names_size() ? parameter_names(i) : "(unknown)", ": ",
|
||||||
|
ShapeUtil::HumanString(parameters(i)));
|
||||||
|
}
|
||||||
|
return absl::StrCat("(", absl::StrJoin(parameter_strings, ", "), ") -> ",
|
||||||
|
ShapeUtil::HumanString(result()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape) {
|
||||||
|
out << program_shape.ToString() << "\n";
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
108
tensorflow/compiler/xla/shape.h
Normal file
108
tensorflow/compiler/xla/shape.h
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_SHAPE_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// Shape of the parameters and output of an XLA computation. This is analogous
|
||||||
|
// to a traditional function signature.
|
||||||
|
class ProgramShape {
|
||||||
|
public:
|
||||||
|
ProgramShape() = default;
|
||||||
|
|
||||||
|
// Creates a ProgramShape from a ProgramShapeProto protobuf.
|
||||||
|
explicit ProgramShape(const ProgramShapeProto& program_shape_proto);
|
||||||
|
|
||||||
|
// Returns a proto representation of the object.
|
||||||
|
ProgramShapeProto ToProto() const;
|
||||||
|
|
||||||
|
string ToString() const;
|
||||||
|
|
||||||
|
// The following methods mirror the protobuf generated code interface for the
|
||||||
|
// message ProgramShapeProto. This enabled easy migration of this data
|
||||||
|
// structure from a proto to a proper C++ class.
|
||||||
|
// TODO(b/29771030): Replace or augment these methods with a more ergonomic
|
||||||
|
// interface.
|
||||||
|
|
||||||
|
// Methods for accessing and manipulating the Shape of the parameters.
|
||||||
|
int parameters_size() const { return parameters_.size(); }
|
||||||
|
const Shape& parameters(int index) const { return parameters_.at(index); }
|
||||||
|
Shape* mutable_parameters(int index) { return ¶meters_.at(index); }
|
||||||
|
Shape* add_parameters() {
|
||||||
|
parameters_.emplace_back();
|
||||||
|
return ¶meters_.back();
|
||||||
|
}
|
||||||
|
void clear_parameters() { parameters_.clear(); }
|
||||||
|
const std::vector<Shape>& parameters() const { return parameters_; }
|
||||||
|
std::vector<Shape>* mutable_parameters() { return ¶meters_; }
|
||||||
|
|
||||||
|
// Methods for accessing and manipulating the Shape of the result.
|
||||||
|
const Shape& result() const { return result_; }
|
||||||
|
Shape* mutable_result() { return &result_; }
|
||||||
|
void clear_result() { result_.Clear(); }
|
||||||
|
|
||||||
|
// Methods for accessing and manipulating the names of the parameters.
|
||||||
|
int parameter_names_size() const { return parameter_names_.size(); }
|
||||||
|
const string& parameter_names(int index) const {
|
||||||
|
return parameter_names_.at(index);
|
||||||
|
}
|
||||||
|
void set_parameter_names(int index, const string& value) {
|
||||||
|
parameter_names_.at(index) = value;
|
||||||
|
}
|
||||||
|
string* mutable_parameter_names(int index) {
|
||||||
|
return ¶meter_names_.at(index);
|
||||||
|
}
|
||||||
|
void add_parameter_names(const string& value) {
|
||||||
|
parameter_names_.push_back(value);
|
||||||
|
}
|
||||||
|
string* add_parameter_names() {
|
||||||
|
parameter_names_.push_back("");
|
||||||
|
return ¶meter_names_.back();
|
||||||
|
}
|
||||||
|
void clear_parameter_names() { parameter_names_.clear(); }
|
||||||
|
const std::vector<string>& parameter_names() const {
|
||||||
|
return parameter_names_;
|
||||||
|
}
|
||||||
|
std::vector<string>* mutable_parameter_names() { return ¶meter_names_; }
|
||||||
|
|
||||||
|
string ShortDebugString() const { return ToProto().ShortDebugString(); }
|
||||||
|
string DebugString() const { return ToProto().DebugString(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The shapes of the parameters of the computation represented by this object.
|
||||||
|
std::vector<Shape> parameters_;
|
||||||
|
|
||||||
|
// The names of the parameters of the computation represented by this object.
|
||||||
|
std::vector<string> parameter_names_;
|
||||||
|
|
||||||
|
// The shape of the result of the computation represented by this object.
|
||||||
|
Shape result_;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_
|
112
tensorflow/compiler/xla/shape_test.cc
Normal file
112
tensorflow/compiler/xla/shape_test.cc
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/strings/str_join.h"
|
||||||
|
#include "tensorflow/compiler/xla/layout_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(ShapeTest, ProgramShapeToFromProto) {
|
||||||
|
ProgramShape program_shape;
|
||||||
|
*program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3});
|
||||||
|
*program_shape.add_parameters() = ShapeUtil::MakeTokenShape();
|
||||||
|
*program_shape.add_parameters() = ShapeUtil::MakeShape(S64, {});
|
||||||
|
*program_shape.add_parameters() = ShapeUtil::MakeTupleShape(
|
||||||
|
{ShapeUtil::MakeShape(S32, {}),
|
||||||
|
ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}),
|
||||||
|
ShapeUtil::MakeShape(F32, {42, 42})});
|
||||||
|
|
||||||
|
*program_shape.mutable_result() = ShapeUtil::MakeShape(F32, {7});
|
||||||
|
|
||||||
|
program_shape.add_parameter_names("foo");
|
||||||
|
program_shape.add_parameter_names("bar");
|
||||||
|
program_shape.add_parameter_names("baz");
|
||||||
|
program_shape.add_parameter_names("qux qux");
|
||||||
|
|
||||||
|
// Create a copy of the program shape by round-tripping through a proto.
|
||||||
|
ProgramShape program_shape_copy(program_shape.ToProto());
|
||||||
|
ASSERT_EQ(program_shape.parameters_size(),
|
||||||
|
program_shape_copy.parameters_size());
|
||||||
|
for (int i = 0; i < program_shape.parameters_size(); ++i) {
|
||||||
|
EXPECT_TRUE(ShapeUtil::Equal(program_shape.parameters(i),
|
||||||
|
program_shape_copy.parameters(i)));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_TRUE(
|
||||||
|
ShapeUtil::Equal(program_shape.result(), program_shape_copy.result()));
|
||||||
|
|
||||||
|
ASSERT_EQ(program_shape.parameter_names_size(),
|
||||||
|
program_shape_copy.parameter_names_size());
|
||||||
|
for (int i = 0; i < program_shape.parameter_names_size(); ++i) {
|
||||||
|
EXPECT_EQ(program_shape.parameter_names(i),
|
||||||
|
program_shape_copy.parameter_names(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ShapeTest, ProgramShapeToString) {
|
||||||
|
Shape opaque = ShapeUtil::MakeOpaqueShape();
|
||||||
|
Shape token = ShapeUtil::MakeTokenShape();
|
||||||
|
Shape scalar = ShapeUtil::MakeShape(F32, {});
|
||||||
|
Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
|
||||||
|
Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
|
||||||
|
Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
|
||||||
|
Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});
|
||||||
|
|
||||||
|
ProgramShape prog = ShapeUtil::MakeProgramShape(
|
||||||
|
{opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
|
||||||
|
EXPECT_EQ(
|
||||||
|
"((unknown): opaque[], "
|
||||||
|
"(unknown): f32[], "
|
||||||
|
"(unknown): u32[1,2], "
|
||||||
|
"(unknown): s32[3,4], "
|
||||||
|
"(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
|
||||||
|
"(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
|
||||||
|
"-> "
|
||||||
|
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
|
||||||
|
ShapeUtil::HumanString(prog));
|
||||||
|
|
||||||
|
prog.add_parameter_names("arg0");
|
||||||
|
prog.add_parameter_names("scalar");
|
||||||
|
prog.add_parameter_names("matrix");
|
||||||
|
prog.add_parameter_names("matrix2");
|
||||||
|
prog.add_parameter_names("tuple");
|
||||||
|
prog.add_parameter_names("nested_tuple");
|
||||||
|
EXPECT_EQ(
|
||||||
|
"(arg0: opaque[], "
|
||||||
|
"scalar: f32[], "
|
||||||
|
"matrix: u32[1,2], "
|
||||||
|
"matrix2: s32[3,4], "
|
||||||
|
"tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
|
||||||
|
"nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
|
||||||
|
"token[])) "
|
||||||
|
"-> "
|
||||||
|
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
|
||||||
|
ShapeUtil::HumanString(prog));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace xla
|
@ -563,6 +563,20 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
|
|||||||
HumanString(program_shape.result()));
|
HumanString(program_shape.result()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ string ShapeUtil::HumanString(
|
||||||
|
const ProgramShapeProto& program_shape_proto) {
|
||||||
|
std::vector<string> parameters;
|
||||||
|
for (auto& shape : program_shape_proto.parameters()) {
|
||||||
|
const int i = parameters.size();
|
||||||
|
parameters.push_back(StrCat(i < program_shape_proto.parameter_names_size()
|
||||||
|
? program_shape_proto.parameter_names(i)
|
||||||
|
: "(unknown)",
|
||||||
|
": ", HumanString(shape)));
|
||||||
|
}
|
||||||
|
return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
|
||||||
|
HumanString(program_shape_proto.result()));
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Parses shapes with simple recursive descent structure -- consumes from the
|
// Parses shapes with simple recursive descent structure -- consumes from the
|
||||||
// front of s and passes that view recursively as required.
|
// front of s and passes that view recursively as required.
|
||||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/layout_util.h"
|
#include "tensorflow/compiler/xla/layout_util.h"
|
||||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
@ -239,6 +240,7 @@ class ShapeUtil {
|
|||||||
//
|
//
|
||||||
// (param_name: f32[42x12], ...) -> f32[24x42]
|
// (param_name: f32[42x12], ...) -> f32[24x42]
|
||||||
static string HumanString(const ProgramShape& program_shape);
|
static string HumanString(const ProgramShape& program_shape);
|
||||||
|
static string HumanString(const ProgramShapeProto& program_shape_proto);
|
||||||
|
|
||||||
// Parses a ShapeUtil::HumanString-format shape string back into a shape
|
// Parses a ShapeUtil::HumanString-format shape string back into a shape
|
||||||
// object.
|
// object.
|
||||||
|
@ -575,37 +575,6 @@ TEST(ShapeUtilTest, HumanString) {
|
|||||||
"((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
|
"((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
|
||||||
"token[])",
|
"token[])",
|
||||||
ShapeUtil::HumanStringWithLayout(nested_tuple));
|
ShapeUtil::HumanStringWithLayout(nested_tuple));
|
||||||
|
|
||||||
ProgramShape prog = ShapeUtil::MakeProgramShape(
|
|
||||||
{opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
|
|
||||||
EXPECT_EQ(
|
|
||||||
"((unknown): opaque[], "
|
|
||||||
"(unknown): f32[], "
|
|
||||||
"(unknown): u32[1,2], "
|
|
||||||
"(unknown): s32[3,4], "
|
|
||||||
"(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
|
|
||||||
"(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
|
|
||||||
"-> "
|
|
||||||
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
|
|
||||||
ShapeUtil::HumanString(prog));
|
|
||||||
|
|
||||||
prog.add_parameter_names("arg0");
|
|
||||||
prog.add_parameter_names("scalar");
|
|
||||||
prog.add_parameter_names("matrix");
|
|
||||||
prog.add_parameter_names("matrix2");
|
|
||||||
prog.add_parameter_names("tuple");
|
|
||||||
prog.add_parameter_names("nested_tuple");
|
|
||||||
EXPECT_EQ(
|
|
||||||
"(arg0: opaque[], "
|
|
||||||
"scalar: f32[], "
|
|
||||||
"matrix: u32[1,2], "
|
|
||||||
"matrix2: s32[3,4], "
|
|
||||||
"tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
|
|
||||||
"nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
|
|
||||||
"token[])) "
|
|
||||||
"-> "
|
|
||||||
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
|
|
||||||
ShapeUtil::HumanString(prog));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ShapeUtilTest, ForEachSubshapeArray) {
|
TEST(ShapeUtilTest, ForEachSubshapeArray) {
|
||||||
|
@ -55,7 +55,8 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) {
|
|||||||
client_->GetComputationShape(computation).ConsumeValueOrDie();
|
client_->GetComputationShape(computation).ConsumeValueOrDie();
|
||||||
std::unique_ptr<ProgramShape> replayed_shape =
|
std::unique_ptr<ProgramShape> replayed_shape =
|
||||||
client_->GetComputationShape(replayed).ConsumeValueOrDie();
|
client_->GetComputationShape(replayed).ConsumeValueOrDie();
|
||||||
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
|
ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
|
||||||
|
replayed_shape->ToProto()));
|
||||||
|
|
||||||
// Run it.
|
// Run it.
|
||||||
Literal literal =
|
Literal literal =
|
||||||
@ -87,7 +88,8 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
|
|||||||
client_->GetComputationShape(computation).ConsumeValueOrDie();
|
client_->GetComputationShape(computation).ConsumeValueOrDie();
|
||||||
std::unique_ptr<ProgramShape> replayed_shape =
|
std::unique_ptr<ProgramShape> replayed_shape =
|
||||||
client_->GetComputationShape(replayed).ConsumeValueOrDie();
|
client_->GetComputationShape(replayed).ConsumeValueOrDie();
|
||||||
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
|
ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
|
||||||
|
replayed_shape->ToProto()));
|
||||||
|
|
||||||
// Run it.
|
// Run it.
|
||||||
std::unique_ptr<GlobalData> x_data =
|
std::unique_ptr<GlobalData> x_data =
|
||||||
@ -133,7 +135,8 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) {
|
|||||||
client_->GetComputationShape(computation).ConsumeValueOrDie();
|
client_->GetComputationShape(computation).ConsumeValueOrDie();
|
||||||
std::unique_ptr<ProgramShape> replayed_shape =
|
std::unique_ptr<ProgramShape> replayed_shape =
|
||||||
client_->GetComputationShape(replayed).ConsumeValueOrDie();
|
client_->GetComputationShape(replayed).ConsumeValueOrDie();
|
||||||
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
|
ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
|
||||||
|
replayed_shape->ToProto()));
|
||||||
|
|
||||||
// Run it.
|
// Run it.
|
||||||
Literal literal =
|
Literal literal =
|
||||||
|
@ -183,7 +183,7 @@ message Shape {
|
|||||||
|
|
||||||
// Shape of the parameters and output of a computation (like a traditional
|
// Shape of the parameters and output of a computation (like a traditional
|
||||||
// function signature).
|
// function signature).
|
||||||
message ProgramShape {
|
message ProgramShapeProto {
|
||||||
repeated Shape parameters = 1;
|
repeated Shape parameters = 1;
|
||||||
Shape result = 2;
|
Shape result = 2;
|
||||||
repeated string parameter_names = 3;
|
repeated string parameter_names = 3;
|
||||||
|
@ -174,11 +174,12 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) {
|
|||||||
ctx->set_output(0, handle_output);
|
ctx->set_output(0, handle_output);
|
||||||
|
|
||||||
xla::LocalExecutable* executable = entry->get().get_executable();
|
xla::LocalExecutable* executable = entry->get().get_executable();
|
||||||
xla::ProgramShape program_shape = executable->executable()
|
xla::ProgramShapeProto program_shape = executable->executable()
|
||||||
->module()
|
->module()
|
||||||
.config()
|
.config()
|
||||||
.entry_computation_layout()
|
.entry_computation_layout()
|
||||||
.ComputeProgramShape();
|
.ComputeProgramShape()
|
||||||
|
.ToProto();
|
||||||
Tensor program_shape_output(DT_STRING, TensorShape({1}));
|
Tensor program_shape_output(DT_STRING, TensorShape({1}));
|
||||||
program_shape_output.vec<string>()(0) = program_shape.SerializeAsString();
|
program_shape_output.vec<string>()(0) = program_shape.SerializeAsString();
|
||||||
ctx->set_output(1, program_shape_output);
|
ctx->set_output(1, program_shape_output);
|
||||||
|
@ -411,7 +411,7 @@ TEST(RawApiTest, CompileAndExecute) {
|
|||||||
auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
|
auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
|
||||||
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
|
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
|
||||||
|
|
||||||
xla::ProgramShape program_shape;
|
xla::ProgramShapeProto program_shape;
|
||||||
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
|
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
|
||||||
EXPECT_EQ(program_shape.parameters_size(), 2);
|
EXPECT_EQ(program_shape.parameters_size(), 2);
|
||||||
}
|
}
|
||||||
@ -465,7 +465,7 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
|
|||||||
auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
|
auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
|
||||||
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
|
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
|
||||||
|
|
||||||
xla::ProgramShape program_shape;
|
xla::ProgramShapeProto program_shape;
|
||||||
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
|
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
|
||||||
EXPECT_EQ(program_shape.parameters_size(), 2);
|
EXPECT_EQ(program_shape.parameters_size(), 2);
|
||||||
}
|
}
|
||||||
@ -510,7 +510,7 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
|
|||||||
TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(),
|
TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(),
|
||||||
{c_handle.program_shape}, {release}, &outputs));
|
{c_handle.program_shape}, {release}, &outputs));
|
||||||
|
|
||||||
xla::ProgramShape program_shape;
|
xla::ProgramShapeProto program_shape;
|
||||||
EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec<string>()(0)));
|
EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec<string>()(0)));
|
||||||
EXPECT_EQ(program_shape.parameters_size(), 1);
|
EXPECT_EQ(program_shape.parameters_size(), 1);
|
||||||
|
|
||||||
@ -520,7 +520,7 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
|
|||||||
<< xla::ShapeUtil::HumanStringWithLayout(program_shape.result());
|
<< xla::ShapeUtil::HumanStringWithLayout(program_shape.result());
|
||||||
|
|
||||||
xla::ProgramShape xla_program_shape =
|
xla::ProgramShape xla_program_shape =
|
||||||
XlaCompiledProgramShape(xla_computation, *shapes);
|
XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
|
||||||
EXPECT_TRUE(xla::LayoutUtil::Equal(
|
EXPECT_TRUE(xla::LayoutUtil::Equal(
|
||||||
xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
|
xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
|
||||||
xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
|
xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
|
||||||
@ -739,7 +739,7 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
|
|||||||
auto expected = xla::LiteralUtil::CreateR0<int64>(15123899);
|
auto expected = xla::LiteralUtil::CreateR0<int64>(15123899);
|
||||||
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
|
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
|
||||||
|
|
||||||
xla::ProgramShape program_shape;
|
xla::ProgramShapeProto program_shape;
|
||||||
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
|
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
|
||||||
EXPECT_EQ(program_shape.parameters_size(), 2);
|
EXPECT_EQ(program_shape.parameters_size(), 2);
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
|
@ -36,11 +36,11 @@ message XLAComputationConfig {
|
|||||||
tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3;
|
tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3;
|
||||||
|
|
||||||
// The arg/result shapes for the whole computation.
|
// The arg/result shapes for the whole computation.
|
||||||
xla.ProgramShape program_shape = 4;
|
xla.ProgramShapeProto program_shape = 4;
|
||||||
// The arg/result shapes for each core of a model-parallel
|
// The arg/result shapes for each core of a model-parallel
|
||||||
// computation. per_core_args_and_result_shapes is optional for a
|
// computation. per_core_args_and_result_shapes is optional for a
|
||||||
// single-core computation.
|
// single-core computation.
|
||||||
repeated xla.ProgramShape per_core_program_shape = 5;
|
repeated xla.ProgramShapeProto per_core_program_shape = 5;
|
||||||
// Describes how replicated computation instances should be assigned to
|
// Describes how replicated computation instances should be assigned to
|
||||||
// devices. There are num_cores_per_replica computations, and each one will be
|
// devices. There are num_cores_per_replica computations, and each one will be
|
||||||
// sent and executed to the set of replica device numbers described in the
|
// sent and executed to the set of replica device numbers described in the
|
||||||
|
Loading…
Reference in New Issue
Block a user