[XLA] Make HloInstruction::backend_config() a JSON-encoded protobuf.
PiperOrigin-RevId: 198754463
This commit is contained in:
parent
38a2a66fa9
commit
10fa513e15
@ -499,37 +499,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "scanner",
|
||||
srcs = ["scanner.cc"],
|
||||
hdrs = ["scanner.h"],
|
||||
visibility = [":internal"],
|
||||
deps = [
|
||||
":status",
|
||||
":status_macros",
|
||||
":types",
|
||||
":util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "scanner_test",
|
||||
srcs = ["scanner_test.cc"],
|
||||
deps = [
|
||||
":scanner",
|
||||
":status",
|
||||
":status_macros",
|
||||
":test",
|
||||
":types",
|
||||
":util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "text_literal_reader",
|
||||
srcs = ["text_literal_reader.cc"],
|
||||
|
@ -1,197 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/scanner.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
// Returns true if c can be the first character in an identifier.
|
||||
bool IsIdentifierFirst(int c) { return std::isalpha(c) || c == '_'; }
|
||||
|
||||
// Returns true if c can be the non-first character in an identifier.
|
||||
bool IsIdentifierLater(int c) { return std::isalnum(c) || c == '_'; }
|
||||
|
||||
// Returns true if str is an identifier.
|
||||
bool IsIdentifier(tensorflow::StringPiece str) {
|
||||
if (str.empty() || !IsIdentifierFirst(str[0])) {
|
||||
return false;
|
||||
}
|
||||
for (int64 i = 1; i < str.size(); ++i) {
|
||||
if (!IsIdentifierLater(str[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Scanner::Scanner(tensorflow::StringPiece input) : input_(input), position_(0) {}
|
||||
|
||||
bool Scanner::ok() const { return status().ok(); }
|
||||
|
||||
const Status& Scanner::status() const { return status_; }
|
||||
|
||||
bool Scanner::Match(tensorflow::StringPiece match) {
|
||||
SkipWhitespace();
|
||||
if (ok() && position_ + match.size() <= input_.size() &&
|
||||
std::equal(match.begin(), match.end(), input_.begin() + position_)) {
|
||||
SkipChars(match.size());
|
||||
|
||||
VLOG(10) << "Matched \"" << match << "\"";
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void Scanner::Expect(tensorflow::StringPiece expect) {
|
||||
if (!Match(expect)) {
|
||||
SetError(tensorflow::strings::StrCat("Expected \"", expect, "\"."));
|
||||
}
|
||||
}
|
||||
|
||||
bool Scanner::MatchReadIdentifier(string* identifier) {
|
||||
SkipWhitespace();
|
||||
if (!IsIdentifierFirst(PeekChar())) {
|
||||
return false;
|
||||
}
|
||||
identifier->clear();
|
||||
do {
|
||||
*identifier += ReadChar();
|
||||
} while (IsIdentifierLater(PeekChar()));
|
||||
|
||||
VLOG(10) << "Read identifier " << identifier;
|
||||
CHECK(IsIdentifier(*identifier));
|
||||
return true;
|
||||
}
|
||||
|
||||
string Scanner::ReadIdentifier() {
|
||||
string identifier;
|
||||
if (!MatchReadIdentifier(&identifier)) {
|
||||
SetError("Expected identifier.");
|
||||
}
|
||||
return identifier;
|
||||
}
|
||||
|
||||
void Scanner::ExpectIdentifier(tensorflow::StringPiece expect) {
|
||||
CHECK(IsIdentifier(expect));
|
||||
|
||||
string identifier;
|
||||
if (!MatchReadIdentifier(&identifier)) {
|
||||
SetError(tensorflow::strings::StrCat("Expected identifier ", expect, "."));
|
||||
}
|
||||
if (identifier != expect) {
|
||||
SetError(tensorflow::strings::StrCat("Expected identifier ", expect,
|
||||
", but got ", identifier, "."));
|
||||
}
|
||||
}
|
||||
|
||||
// Matches the end of the input, also known as End Of File (EOF).
|
||||
bool Scanner::MatchEof() {
|
||||
SkipWhitespace();
|
||||
return PeekChar() == EOF;
|
||||
}
|
||||
|
||||
void Scanner::ExpectEof() {
|
||||
if (!MatchEof()) {
|
||||
SetError("Expected end of input.");
|
||||
}
|
||||
}
|
||||
|
||||
// Reads a vector of the format "(1, 2, 3)".
|
||||
std::vector<int64> Scanner::ReadIntVector() {
|
||||
std::vector<int64> ints;
|
||||
Expect("(");
|
||||
if (!Match(")") && ok()) {
|
||||
ints.push_back(ReadInt());
|
||||
while (Match(",")) {
|
||||
ints.push_back(ReadInt());
|
||||
}
|
||||
Expect(")");
|
||||
}
|
||||
|
||||
VLOG(10) << "Read int vector with " << ints.size() << " elements.";
|
||||
return ints;
|
||||
}
|
||||
|
||||
int64 Scanner::ReadInt() {
|
||||
bool negative = Match("-");
|
||||
if (!PeekDigit()) {
|
||||
SetError("Expected integer.");
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64 integer = 0;
|
||||
do {
|
||||
integer = (ReadChar() - '0') + integer * 10;
|
||||
} while (PeekDigit());
|
||||
integer = negative ? -integer : integer;
|
||||
|
||||
VLOG(10) << "Read integer " << integer;
|
||||
return integer;
|
||||
}
|
||||
|
||||
void Scanner::SkipWhitespace() {
|
||||
while (PeekWhitespace()) {
|
||||
SkipChars(1);
|
||||
}
|
||||
}
|
||||
|
||||
int Scanner::ReadChar() {
|
||||
int c = PeekChar();
|
||||
SkipChars(1);
|
||||
|
||||
VLOG(20) << "Read char " << c;
|
||||
return c;
|
||||
}
|
||||
|
||||
int Scanner::PeekChar() const {
|
||||
return ok() && position_ < input_.size() ? input_[position_] : EOF;
|
||||
}
|
||||
|
||||
bool Scanner::PeekDigit() const {
|
||||
// Do not use std::isdigit since it depends on the locale and we do not
|
||||
// handle any digits beyond 0-9.
|
||||
const char c = PeekChar();
|
||||
return '0' <= c && c <= '9';
|
||||
}
|
||||
|
||||
bool Scanner::PeekAlnum() const { return std::isalnum(PeekChar()); }
|
||||
|
||||
bool Scanner::PeekWhitespace() const { return std::isspace(PeekChar()); }
|
||||
|
||||
void Scanner::SkipChars(int64 count) {
|
||||
CHECK_GE(count, 0);
|
||||
position_ += count;
|
||||
}
|
||||
|
||||
void Scanner::SetError(string error_message) {
|
||||
// Only the first error is recorded since any later errors will likely be a
|
||||
// consequence of the first error.
|
||||
if (ok()) {
|
||||
status_ = InvalidArgumentStrCat(std::move(error_message));
|
||||
position_ = input_.size();
|
||||
VLOG(10) << "Failed scanner with error " << status_.ToString();
|
||||
} else {
|
||||
VLOG(10) << "Error on already failed scanner is " << error_message;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xla
|
@ -1,102 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SCANNER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SCANNER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Simple class for parsing data. The concepts for the interface are:
|
||||
//
|
||||
// Match(x): Returns true if x is next in the input and in that case skips
|
||||
// past it. Otherwise returns false.
|
||||
//
|
||||
// Expect(x): As Match(x), but requires x to be next in the input.
|
||||
//
|
||||
// MatchReadX(x): Returns true if an X is next in the input and in that case
|
||||
// skips past it and assigns it to x. Otherwise returns false.
|
||||
//
|
||||
// ReadX(): As ReadMatchX(), but requires an X to be next in the input and
|
||||
// returns it.
|
||||
//
|
||||
// PeekX(): Returns true if an X is next in the input and does not skip
|
||||
// past it either way.
|
||||
//
|
||||
// All of these, except those that work on individual characters, skip
|
||||
// whitespace.
|
||||
//
|
||||
// If a requirement is not met, the error is available in status(). A Scanner
|
||||
// with a failed status() will behave as though the rest of the input is EOF and
|
||||
// will not record further errors after that point.
|
||||
class Scanner {
|
||||
public:
|
||||
Scanner(tensorflow::StringPiece input);
|
||||
|
||||
bool ok() const;
|
||||
const Status& status() const;
|
||||
|
||||
bool Match(tensorflow::StringPiece match);
|
||||
void Expect(tensorflow::StringPiece expect);
|
||||
|
||||
// Match-reads an identifier. An identifier starts with an alphabetic
|
||||
// character or an underscore followed by any number of characters that are
|
||||
// each alphanumeric or underscore.
|
||||
bool MatchReadIdentifier(string* identifier);
|
||||
|
||||
string ReadIdentifier();
|
||||
|
||||
void ExpectIdentifier(tensorflow::StringPiece expect);
|
||||
|
||||
// Matches the end of the input, also known as End Of File (EOF).
|
||||
bool MatchEof();
|
||||
void ExpectEof();
|
||||
|
||||
// Reads a vector of the format "(1, 4, 5)".
|
||||
std::vector<int64> ReadIntVector();
|
||||
|
||||
// Reads an integer. Can start with a minus but not a plus.
|
||||
int64 ReadInt();
|
||||
|
||||
// Keeps skipping until encountering a non-whitespace character.
|
||||
void SkipWhitespace();
|
||||
|
||||
// *** Below here are character-level methods that do not skip whitespace.
|
||||
|
||||
int ReadChar();
|
||||
int PeekChar() const;
|
||||
bool PeekDigit() const;
|
||||
bool PeekAlnum() const;
|
||||
bool PeekWhitespace() const;
|
||||
|
||||
// Skip past the next count characters.
|
||||
void SkipChars(int64 count);
|
||||
|
||||
private:
|
||||
// Sets a failed status. The input is in effect replaced with EOF after
|
||||
// this. Only the first error is recorded.
|
||||
void SetError(string error_message);
|
||||
|
||||
const tensorflow::StringPiece input_;
|
||||
int64 position_;
|
||||
Status status_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SCANNER_H_
|
@ -1,124 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// TODO(b/80179519): Fix open source build for real.
|
||||
#if 0
|
||||
#include "tensorflow/compiler/xla/scanner.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
TEST(Scanner, Empty) {
|
||||
Scanner scanner("");
|
||||
|
||||
EXPECT_EQ(scanner.PeekChar(), EOF);
|
||||
EXPECT_TRUE(scanner.MatchEof());
|
||||
EXPECT_TRUE(scanner.Match(""));
|
||||
EXPECT_FALSE(scanner.Match("1"));
|
||||
EXPECT_TRUE(scanner.ok());
|
||||
}
|
||||
|
||||
TEST(Scanner, Prefix) {
|
||||
Scanner scanner("1234 5");
|
||||
EXPECT_FALSE(scanner.MatchEof());
|
||||
EXPECT_TRUE(scanner.Match("12"));
|
||||
EXPECT_TRUE(scanner.Match("34 "));
|
||||
EXPECT_FALSE(scanner.MatchEof());
|
||||
EXPECT_FALSE(scanner.Match("5 "));
|
||||
EXPECT_TRUE(scanner.Match("5"));
|
||||
EXPECT_TRUE(scanner.MatchEof());
|
||||
}
|
||||
|
||||
TEST(Scanner, Whitespace) {
|
||||
Scanner scanner(" \t\n\r 1\t2\n\n");
|
||||
|
||||
EXPECT_FALSE(scanner.Match(" "));
|
||||
EXPECT_TRUE(scanner.Match("1"));
|
||||
EXPECT_TRUE(scanner.Match("2"));
|
||||
EXPECT_TRUE(scanner.MatchEof());
|
||||
EXPECT_TRUE(scanner.ok());
|
||||
}
|
||||
|
||||
TEST(Scanner, Fail) {
|
||||
Scanner scanner("153 4q");
|
||||
|
||||
scanner.Expect("5");
|
||||
EXPECT_FALSE(scanner.ok());
|
||||
EXPECT_FALSE(scanner.status().ok());
|
||||
|
||||
EXPECT_TRUE(scanner.MatchEof());
|
||||
}
|
||||
|
||||
TEST(Scanner, Identifier) {
|
||||
Scanner scanner("1 q1 _1_ _1a= qqb");
|
||||
|
||||
string identifier = "foo";
|
||||
EXPECT_FALSE(scanner.MatchReadIdentifier(&identifier));
|
||||
EXPECT_EQ(identifier, "foo");
|
||||
scanner.Match("1");
|
||||
|
||||
EXPECT_TRUE(scanner.MatchReadIdentifier(&identifier));
|
||||
EXPECT_EQ(identifier, "q1");
|
||||
|
||||
scanner.ExpectIdentifier("_1_");
|
||||
EXPECT_TRUE(scanner.ok());
|
||||
|
||||
scanner.ExpectIdentifier("_1a");
|
||||
EXPECT_TRUE(scanner.ok());
|
||||
|
||||
// The = after _1a is not included in the identifier.
|
||||
scanner.Expect("=");
|
||||
|
||||
// The expected identifier matches a prefix but is not the full identifier in
|
||||
// the input.
|
||||
EXPECT_TRUE(scanner.ok());
|
||||
scanner.ExpectIdentifier("qq");
|
||||
EXPECT_FALSE(scanner.ok());
|
||||
}
|
||||
|
||||
TEST(Scanner, Int) {
|
||||
Scanner scanner("1_2 3% -1 124345 -363 0 -0");
|
||||
EXPECT_EQ(1, scanner.ReadInt());
|
||||
EXPECT_TRUE(scanner.Match("_"));
|
||||
EXPECT_EQ(2, scanner.ReadInt());
|
||||
EXPECT_EQ(3, scanner.ReadInt());
|
||||
EXPECT_TRUE(scanner.Match("%"));
|
||||
EXPECT_EQ(-1, scanner.ReadInt());
|
||||
EXPECT_EQ(124345, scanner.ReadInt());
|
||||
EXPECT_EQ(-363, scanner.ReadInt());
|
||||
EXPECT_EQ(0, scanner.ReadInt());
|
||||
EXPECT_EQ(0, scanner.ReadInt());
|
||||
EXPECT_TRUE(scanner.MatchEof());
|
||||
}
|
||||
|
||||
TEST(Scanner, IntVector) {
|
||||
Scanner scanner("()(0) (-1,2) ( 3 , 4 )");
|
||||
EXPECT_THAT(scanner.ReadIntVector(), testing::IsEmpty());
|
||||
EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(0));
|
||||
EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(-1, 2));
|
||||
EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(3, 4));
|
||||
EXPECT_TRUE(scanner.MatchEof());
|
||||
EXPECT_TRUE(scanner.ok());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
#endif
|
@ -309,6 +309,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:human_readable_json",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
|
@ -28,8 +28,9 @@ namespace xla {
|
||||
/* static */ tensorflow::mutex Compiler::platform_compiler_mutex_(
|
||||
tensorflow::LINKER_INITIALIZED);
|
||||
|
||||
std::vector<string> Compiler::ComputeBackendConfigs(
|
||||
const HloInstruction& hlo, se::StreamExecutor* executor) const {
|
||||
std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
|
||||
Compiler::ComputeBackendConfigs(const HloInstruction& hlo,
|
||||
se::StreamExecutor* executor) const {
|
||||
CHECK(executor != nullptr);
|
||||
return {};
|
||||
}
|
||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
|
||||
@ -161,8 +162,9 @@ class Compiler {
|
||||
//
|
||||
// The stream executor is passed in to provide information about the hardware
|
||||
// that the backend configurations would be targeting.
|
||||
virtual std::vector<string> ComputeBackendConfigs(
|
||||
const HloInstruction& hlo, se::StreamExecutor* executor) const;
|
||||
virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
|
||||
ComputeBackendConfigs(const HloInstruction& hlo,
|
||||
se::StreamExecutor* executor) const;
|
||||
|
||||
// Compiles the HLO module for ahead-of-time execution. This is intended for
|
||||
// use in static compilation.
|
||||
|
@ -1085,11 +1085,11 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
|
||||
|
||||
string HloDotDumper::GetInstructionNodeBackendConfig(
|
||||
const HloInstruction* instr) {
|
||||
if (!show_backend_config_ || instr->backend_config().empty()) {
|
||||
if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
return StrCat("backend_config=\"", instr->backend_config(), "\"");
|
||||
return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
|
||||
}
|
||||
|
||||
string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
|
||||
|
@ -41,6 +41,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/human_readable_json.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace xla {
|
||||
@ -110,7 +111,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
instruction->name_ = proto.name();
|
||||
|
||||
instruction->metadata_ = proto.metadata();
|
||||
instruction->set_backend_config(proto.backend_config());
|
||||
instruction->backend_config_ = proto.backend_config();
|
||||
if (proto.has_literal()) {
|
||||
TF_ASSIGN_OR_RETURN(instruction->literal_,
|
||||
Literal::CreateFromProto(proto.literal()));
|
||||
@ -1521,7 +1522,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
}
|
||||
SetupDerivedInstruction(clone.get());
|
||||
clone->set_parent(parent_);
|
||||
clone->set_backend_config(backend_config());
|
||||
clone->set_raw_backend_config_string(backend_config_);
|
||||
if (context != nullptr) {
|
||||
context->MapInstruction(this, clone.get());
|
||||
clone->ReplaceCalledComputations([&](HloComputation* callee) {
|
||||
@ -2182,8 +2183,8 @@ string HloInstruction::ToStringWithCanonicalNameMap(
|
||||
!metadata_.source_file().empty())) {
|
||||
StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
|
||||
}
|
||||
if (options.print_backend_config() && !backend_config().empty()) {
|
||||
StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\"");
|
||||
if (options.print_backend_config() && !backend_config_.empty()) {
|
||||
StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\"");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -2463,7 +2464,7 @@ HloInstructionProto HloInstruction::ToProto() const {
|
||||
}
|
||||
|
||||
*proto.mutable_metadata() = metadata_;
|
||||
proto.set_backend_config(backend_config());
|
||||
proto.set_backend_config(backend_config_);
|
||||
if (literal_ != nullptr) {
|
||||
*proto.mutable_literal() = literal_->ToProto();
|
||||
}
|
||||
@ -3526,6 +3527,31 @@ bool HloInstruction::CouldBeBitcast() const {
|
||||
}
|
||||
}
|
||||
|
||||
Status HloInstruction::GetBackendConfigInternal(
|
||||
tensorflow::protobuf::Message* proto) const {
|
||||
proto->Clear();
|
||||
|
||||
// Empty string does not parse as valid JSON, but it's a valid backend config,
|
||||
// corresponding to the empty proto.
|
||||
if (backend_config_.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
return tensorflow::HumanReadableJsonToProto(backend_config_, proto);
|
||||
}
|
||||
|
||||
Status HloInstruction::set_backend_config(
|
||||
const tensorflow::protobuf::Message& proto) {
|
||||
TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* static */ StatusOr<string> HloInstruction::BackendConfigToRawString(
|
||||
const tensorflow::protobuf::Message& proto) {
|
||||
string ret;
|
||||
TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
HloModule* HloInstruction::GetModule() const {
|
||||
if (parent_) {
|
||||
return parent_->parent();
|
||||
|
@ -52,6 +52,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/iterator_range.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
@ -1446,12 +1447,33 @@ class HloInstruction {
|
||||
// this field and they cannot interpret it due to its meaning being backend
|
||||
// specific.
|
||||
//
|
||||
// TODO(b/78194644): Introduce structured configuration format as per
|
||||
// go/xla-heuristics.
|
||||
const string& backend_config() const { return backend_config_; }
|
||||
void set_backend_config(string backend_config) {
|
||||
backend_config_ = std::move(backend_config);
|
||||
// ConfigProto should be a protobuf Message type.
|
||||
template <typename ConfigProto>
|
||||
StatusOr<ConfigProto> backend_config() const {
|
||||
ConfigProto proto;
|
||||
TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
|
||||
return std::move(proto);
|
||||
}
|
||||
Status set_backend_config(const tensorflow::protobuf::Message& proto);
|
||||
|
||||
// Getter/setter for raw JSON-encoded backend config. Prefer the
|
||||
// functions above that deal in proto Messages where possible.
|
||||
const string& raw_backend_config_string() const { return backend_config_; }
|
||||
void set_raw_backend_config_string(string config_str) {
|
||||
backend_config_ = std::move(config_str);
|
||||
}
|
||||
|
||||
// Returns a string representation of a proto in the format used by
|
||||
// raw_backend_config_string.
|
||||
//
|
||||
// This is morally equivalent to:
|
||||
//
|
||||
// HloInstruction instr;
|
||||
// TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
|
||||
// return instr.raw_backend_config_string();
|
||||
//
|
||||
static StatusOr<string> BackendConfigToRawString(
|
||||
const tensorflow::protobuf::Message& proto);
|
||||
|
||||
// Sets the debug metadata for this instruction.
|
||||
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
|
||||
@ -1573,6 +1595,10 @@ class HloInstruction {
|
||||
// Returns how this instruction uses elements of its `i`th operand.
|
||||
UseKind OperandElementUse(int64 i) const;
|
||||
|
||||
// Helper for implementing backend_config(). Parses backend_config_ into the
|
||||
// given proto.
|
||||
Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;
|
||||
|
||||
int unique_id_; // Unique to this HloInstruction within a HloModule
|
||||
|
||||
// Opcode for this instruction.
|
||||
|
@ -1127,7 +1127,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
instruction->set_metadata(*metadata);
|
||||
}
|
||||
if (backend_config) {
|
||||
instruction->set_backend_config(std::move(*backend_config));
|
||||
instruction->set_raw_backend_config_string(std::move(*backend_config));
|
||||
}
|
||||
return AddInstruction(name, instruction, name_loc);
|
||||
} // NOLINT(readability/fn_size)
|
||||
|
@ -1025,7 +1025,7 @@ ENTRY %configuration_test() -> s32[] {
|
||||
EXPECT_EQ("foo bar", result.ValueOrDie()
|
||||
->entry_computation()
|
||||
->root_instruction()
|
||||
->backend_config());
|
||||
->raw_backend_config_string());
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
|
||||
|
@ -101,42 +101,43 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test")
|
||||
# For platform specific build config
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_additional_all_protos",
|
||||
"tf_additional_cloud_kernel_deps",
|
||||
"tf_additional_cloud_op_deps",
|
||||
"tf_additional_core_deps",
|
||||
"tf_additional_cupti_wrapper_deps",
|
||||
"tf_additional_device_tracer_cuda_deps",
|
||||
"tf_additional_device_tracer_deps",
|
||||
"tf_additional_device_tracer_srcs",
|
||||
"tf_additional_gdr_lib_defines",
|
||||
"tf_additional_human_readable_json_deps",
|
||||
"tf_additional_lib_defines",
|
||||
"tf_additional_lib_deps",
|
||||
"tf_additional_libdevice_data",
|
||||
"tf_additional_libdevice_deps",
|
||||
"tf_additional_libdevice_srcs",
|
||||
"tf_additional_lib_hdrs",
|
||||
"tf_additional_lib_srcs",
|
||||
"tf_additional_minimal_lib_srcs",
|
||||
"tf_additional_mpi_lib_defines",
|
||||
"tf_additional_proto_hdrs",
|
||||
"tf_additional_proto_srcs",
|
||||
"tf_additional_test_deps",
|
||||
"tf_additional_test_srcs",
|
||||
"tf_additional_verbs_lib_defines",
|
||||
"tf_jspb_proto_library",
|
||||
"tf_kernel_tests_linkstatic",
|
||||
"tf_lib_proto_parsing_deps",
|
||||
"tf_nano_proto_library",
|
||||
"tf_platform_hdrs",
|
||||
"tf_platform_srcs",
|
||||
"tf_proto_library",
|
||||
"tf_proto_library_cc",
|
||||
"tf_additional_all_protos",
|
||||
"tf_additional_core_deps",
|
||||
"tf_additional_lib_defines",
|
||||
"tf_additional_lib_deps",
|
||||
"tf_additional_lib_hdrs",
|
||||
"tf_additional_lib_srcs",
|
||||
"tf_additional_minimal_lib_srcs",
|
||||
"tf_additional_proto_hdrs",
|
||||
"tf_additional_proto_srcs",
|
||||
"tf_additional_cupti_wrapper_deps",
|
||||
"tf_additional_libdevice_data",
|
||||
"tf_additional_libdevice_deps",
|
||||
"tf_additional_libdevice_srcs",
|
||||
"tf_additional_test_deps",
|
||||
"tf_additional_test_srcs",
|
||||
"tf_kernel_tests_linkstatic",
|
||||
"tf_additional_cloud_op_deps",
|
||||
"tf_additional_cloud_kernel_deps",
|
||||
"tf_lib_proto_parsing_deps",
|
||||
"tf_additional_verbs_lib_defines",
|
||||
"tf_additional_mpi_lib_defines",
|
||||
"tf_additional_gdr_lib_defines",
|
||||
"tf_additional_device_tracer_srcs",
|
||||
"tf_additional_device_tracer_deps",
|
||||
"tf_additional_device_tracer_cuda_deps",
|
||||
"tf_pyclif_proto_library",
|
||||
"tf_jspb_proto_library",
|
||||
"tf_nano_proto_library",
|
||||
"tf_protos_all",
|
||||
"tf_protos_all_impl",
|
||||
"tf_protos_grappler",
|
||||
"tf_protos_grappler_impl",
|
||||
"tf_pyclif_proto_library",
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
@ -400,6 +401,7 @@ cc_library(
|
||||
"protobuf.cc",
|
||||
]) + [
|
||||
"platform/protobuf_util.cc",
|
||||
"lib/core/status.h",
|
||||
],
|
||||
hdrs = [
|
||||
":platform_protobuf_hdrs",
|
||||
@ -416,6 +418,18 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "human_readable_json",
|
||||
srcs = tf_platform_srcs(["human_readable_json.cc"]),
|
||||
hdrs = ["platform/human_readable_json.h"],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":lib",
|
||||
":lib_internal",
|
||||
] + tf_additional_human_readable_json_deps(),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "platform_env_hdrs",
|
||||
srcs = [
|
||||
@ -2013,6 +2027,7 @@ cc_library(
|
||||
"platform/**/cuda_libdevice_path.cc",
|
||||
"platform/**/device_tracer.cc",
|
||||
"platform/**/logging.cc",
|
||||
"platform/**/human_readable_json.cc",
|
||||
"platform/abi.cc",
|
||||
],
|
||||
) + tf_additional_lib_srcs(
|
||||
@ -2025,6 +2040,7 @@ cc_library(
|
||||
"platform/**/env_time.cc",
|
||||
"platform/**/device_tracer.cc",
|
||||
"platform/**/logging.cc",
|
||||
"platform/**/human_readable_json.cc",
|
||||
"platform/abi.cc",
|
||||
] +
|
||||
# Protobuf deps already included through the ":lib_proto_parsing"
|
||||
|
@ -515,6 +515,9 @@ def tf_additional_proto_srcs():
|
||||
"platform/default/protobuf.cc",
|
||||
]
|
||||
|
||||
def tf_additional_human_readable_json_deps():
|
||||
return []
|
||||
|
||||
def tf_additional_all_protos():
|
||||
return ["//tensorflow/core:protos_all"]
|
||||
|
||||
|
54
tensorflow/core/platform/default/human_readable_json.cc
Normal file
54
tensorflow/core/platform/default/human_readable_json.cc
Normal file
@ -0,0 +1,54 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/human_readable_json.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status ProtoToHumanReadableJson(const ::google::protobuf::Message& proto,
|
||||
string* result) {
|
||||
result->clear();
|
||||
|
||||
auto status = google::protobuf::util::MessageToJsonString(proto, result);
|
||||
if (!status.ok()) {
|
||||
// Convert error_msg google::protobuf::StringPiece to
|
||||
// tensorflow::StringPiece.
|
||||
auto error_msg = status.error_message();
|
||||
return errors::Internal(
|
||||
strings::StrCat("Could not convert proto to JSON string: ",
|
||||
StringPiece(error_msg.data(), error_msg.length())));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HumanReadableJsonToProto(const string& str,
|
||||
::google::protobuf::Message* proto) {
|
||||
proto->Clear();
|
||||
auto status = google::protobuf::util::JsonStringToMessage(str, proto);
|
||||
if (!status.ok()) {
|
||||
// Convert error_msg google::protobuf::StringPiece to
|
||||
// tensorflow::StringPiece.
|
||||
auto error_msg = status.error_message();
|
||||
return errors::Internal(
|
||||
strings::StrCat("Could not convert JSON string to proto: ",
|
||||
StringPiece(error_msg.data(), error_msg.length())));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
37
tensorflow/core/platform/human_readable_json.h
Normal file
37
tensorflow/core/platform/human_readable_json.h
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Converts a proto to a JSON-like string that's meant to be human-readable
|
||||
// but still machine-parseable.
|
||||
//
|
||||
// This string may not be strictly JSON-compliant, but it must be parseable by
|
||||
// HumanReadableJSONToProto.
|
||||
Status ProtoToHumanReadableJson(const protobuf::Message& proto, string* result);
|
||||
|
||||
// Converts a string produced by ProtoToHumanReadableJSON to a protobuf. Not
|
||||
// guaranteed to work for general JSON.
|
||||
Status HumanReadableJsonToProto(const string& str, protobuf::Message* proto);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_
|
Loading…
x
Reference in New Issue
Block a user