From 21d2de1c8d34d5094472dd828394c239d6111e0d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Oct 2017 15:02:01 -0700 Subject: [PATCH 01/41] Add a recursive descent parser for the HloModule string. It constructs an HloModule object from a string printed by HloModule::ToString(). This is a initial stage. It currently supports: - unary, binary, ternary ops, and other ops that don't have extra attributes. - module with entry computation only. - simple cases for constant instruction. To make the parser simpler, this cl removes a whitespace and adds a '%' before the computation name in HloComputation::ToString(). Further steps will enable parsing subcomputations, more cases of constants, tuple, and ops that require extra attributes (e.g., broadcast dimensions, subcomputation). PiperOrigin-RevId: 172804214 --- tensorflow/BUILD | 1 + .../compiler/xla/service/hlo_computation.cc | 4 +- tensorflow/compiler/xla/shape_util.cc | 45 +- tensorflow/compiler/xla/tools/parser/BUILD | 84 +++ .../compiler/xla/tools/parser/README.md | 69 +++ .../compiler/xla/tools/parser/hlo_lexer.cc | 270 ++++++++++ .../compiler/xla/tools/parser/hlo_lexer.h | 108 ++++ .../compiler/xla/tools/parser/hlo_parser.cc | 502 ++++++++++++++++++ .../compiler/xla/tools/parser/hlo_parser.h | 37 ++ .../xla/tools/parser/hlo_parser_test.cc | 240 +++++++++ .../compiler/xla/tools/parser/hlo_token.h | 58 ++ 11 files changed, 1402 insertions(+), 16 deletions(-) create mode 100644 tensorflow/compiler/xla/tools/parser/BUILD create mode 100644 tensorflow/compiler/xla/tools/parser/README.md create mode 100644 tensorflow/compiler/xla/tools/parser/hlo_lexer.cc create mode 100644 tensorflow/compiler/xla/tools/parser/hlo_lexer.h create mode 100644 tensorflow/compiler/xla/tools/parser/hlo_parser.cc create mode 100644 tensorflow/compiler/xla/tools/parser/hlo_parser.h create mode 100644 tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc create mode 100644 tensorflow/compiler/xla/tools/parser/hlo_token.h diff --git a/tensorflow/BUILD b/tensorflow/BUILD index e351037abbd..d5c56cdc184 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -340,6 +340,7 @@ filegroup( "//tensorflow/compiler/xla/service/llvm_ir:all_files", "//tensorflow/compiler/xla/tests:all_files", "//tensorflow/compiler/xla/tools:all_files", + "//tensorflow/compiler/xla/tools/parser:all_files", "//tensorflow/contrib:all_files", "//tensorflow/contrib/all_reduce:all_files", "//tensorflow/contrib/android:all_files", diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 9b3104eaacd..51ead753f04 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -373,8 +373,8 @@ string HloComputation::ToString(int nested_level) const { for (int i = 0; i < nested_level; i++) { s << " "; } - s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) - << " { \n"; + s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) + << " {\n"; for (const HloInstruction* instruction : MakeInstructionPostOrder()) { for (int i = 0; i < nested_level; i++) { s << " "; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 8e16056b239..af583bed625 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -102,6 +102,32 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { return true; } +// Constructs and returns the new shape with the given minor_to_major order in +// its Layout. +StatusOr MakeShapeWithLayoutInternal( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice minor_to_major) { + if (dimensions.size() != minor_to_major.size()) { + return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", + dimensions.size(), minor_to_major.size()); + } + if (element_type == OPAQUE || element_type == TUPLE) { + return InvalidArgument("Unsupported element type: %s", + PrimitiveType_Name(element_type).c_str()); + } + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); + min2maj->Clear(); + for (int64 value : minor_to_major) { + min2maj->Add(value); + } + if (!shape.has_layout()) { + return InvalidArgument("Shape has no layout."); + } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); + return shape; +} + } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { @@ -152,16 +178,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { /* static */ Shape ShapeUtil::MakeShapeWithLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice minor_to_major) { - CHECK_EQ(dimensions.size(), minor_to_major.size()); - Shape shape = MakeShape(element_type, dimensions); - auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->Clear(); - for (int64 value : minor_to_major) { - min2maj->Add(value); - } - DCHECK(shape.has_layout()); - TF_DCHECK_OK(ValidateShape(shape)); - return shape; + return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) + .ValueOrDie(); } /* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( @@ -499,11 +517,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // Extract the layout minor-to-major and set it. TF_ASSIGN_OR_RETURN(std::vector min2maj, comma_list_to_int64s(layout_string)); - TF_RET_CHECK(dimensions.size() == min2maj.size()); - result = - ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj); + TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( + primitive_type, dimensions, min2maj)); } - TF_DCHECK_OK(ShapeUtil::ValidateShape(result)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); return std::move(result); } diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD new file mode 100644 index 00000000000..c84ca9fc833 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -0,0 +1,84 @@ +# Build file for the Hlo parser. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [":friends"], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + "hlo_token.h", + ], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + ], +) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo_lexer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_parser", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md new file mode 100644 index 00000000000..a334bc2b297 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -0,0 +1,69 @@ +# HloModule string syntax + +TODO: Support subcomputations (for fusion, reduce, while, ...). + +TODO: Support ops that require extra attributes, e.g. dimensions, strides. + +```yacc +hlo_module + : 'HloModule' name computation + ; + +computation + : 'ENTRY' name param_list '->' shape instruction_list + ; + +instruction_list + : '{' instruction_list1 '}' + ; +instruction_list1 + : instruction + | instruction_list1 instruction + ; +instruction + : name '=' shape opcode operands + ; + +operands + : '(' operands1 ')' + ; +operands1 + : /*empty*/ + | operand + | operands1 ',' operand + ; +operand + : shape name + ; + +param_list + : '(' param_list1 ')' + ; +param_list1 + : /*empty*/ + | param + | param_list1 ',' param + ; +param + : name shape + ; + +shape + : shape_val_ + | '(' tuple_elements ')' + ; +tuple_elements + : /*empty*/ + | shape (',' shape)* + ; + +name + : identifier ':' + | '%' identifier + ; + +identifier + : [a-zA-Z_][a-zA-Z0-9_.-]* + ; + +``` diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc new file mode 100644 index 00000000000..3e84ffcbd2c --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -0,0 +1,270 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" + +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/regexp.h" + +namespace xla { +namespace tools { + +using tensorflow::StringPiece; + +namespace { + +constexpr int kEOF = -1; +constexpr int kError = -2; + +// [a-zA-Z0-9_.-] +bool IsIdentifierChar(char c) { + return isalnum(static_cast(c)) || c == '-' || c == '.' || + c == '_'; +} + +} // namespace + +int HloLexer::GetNextChar() { + int current_char = PeekCurrentChar(); + if (current_char != kEOF && current_char != kError) { + current_ptr_++; + } + return current_char; +} + +int HloLexer::PeekCurrentChar() const { + if (current_ptr_ == buf_.end()) { + return kEOF; + } + char current_char = *current_ptr_; + if (current_char == 0) { + // '\0' should not appear in the middle of the string. + return kError; + } + return static_cast(current_char); +} + +bool HloLexer::CanDereference(const char* ptr) const { + return ptr < buf_.end() && ptr >= buf_.begin(); +} + +StringPiece HloLexer::StringPieceFromPointers(const char* begin, + const char* end) const { + CHECK(begin <= end); + CHECK(begin == buf_.end() || CanDereference(begin)); + CHECK(end == buf_.end() || CanDereference(end)); + return StringPiece(begin, end - begin); +} + +tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( + const char* begin, const char* end) const { + CHECK(begin <= end); + CHECK(begin == buf_.end() || CanDereference(begin)); + CHECK(end == buf_.end() || CanDereference(end)); + return tensorflow::RegexpStringPiece(begin, end - begin); +} + +TokKind HloLexer::LexToken() { + while (true) { + token_start_ = current_ptr_; + + int current_char = GetNextChar(); + switch (current_char) { + default: + // [a-zA-Z_] + if (isalpha(static_cast(current_char)) || + current_char == '_') { + return LexIdentifier(); + } + return TokKind::kError; + case kEOF: + // Hit the end of the input buffer. + return TokKind::kEof; + case kError: + // Hit an invalid character in the input buffer. + return TokKind::kError; + case ' ': + case '\t': + case '\n': + case '\r': + // Ignore whitespace. + continue; + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '-': + if (current_char == '-' && PeekCurrentChar() == '>') { + current_ptr_++; + return TokKind::kArrow; + } + return LexDigitOrNegative(); + case '=': + return TokKind::kEqual; + case ',': + return TokKind::kComma; + case '%': + return LexPercent(); + case ':': + return TokKind::kColon; + case '[': + return TokKind::kLsquare; + case ']': + return TokKind::kRsquare; + case '{': + return TokKind::kLbrace; + case '}': + return TokKind::kRbrace; + case '(': + return TokKind::kLparen; + case ')': + return TokKind::kRparen; + } + } +} + +// Lex a shape, name, keyword, or opcode. +// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? +// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: +// keyword ::= HloModule, ENTRY, ... +// opcode ::= add, greater-than, ... +TokKind HloLexer::LexIdentifier() { + { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + // 'consumable' will be advanced iff its prefix matches the pattern. + static LazyRE2 shape_pattern = { + R"(^(\w*\d*)\[([\d,]*)\](?:\s*{([\d,]*)})?)"}; + if (RE2::Consume(&consumable, *shape_pattern)) { + auto status_or_shape = ShapeUtil::ParseShapeString( + StringPieceFromPointers(token_start_, consumable.begin())); + if (status_or_shape.ok()) { + // This is a shape string. + shape_val_ = status_or_shape.ValueOrDie(); + current_ptr_ = consumable.begin(); + return TokKind::kShape; + } + } + } + + while (IsIdentifierChar(PeekCurrentChar())) { + current_ptr_++; + } + + // If followed by ':', it's a name. + if (PeekCurrentChar() == ':') { + str_val_.assign(token_start_, current_ptr_); + current_ptr_++; // skip ':' + return TokKind::kName; + } + + StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_); + + // See if this is a keyword. +#define KEYWORD(STR) \ + do { \ + if (identifier == #STR) { \ + return TokKind::kw_##STR; \ + } \ + } while (false) + + KEYWORD(true); + KEYWORD(false); + KEYWORD(HloModule); + KEYWORD(ENTRY); + +#undef KEYWORD + + // See if this is an opcode. + auto opcode = StringToHloOpcode(identifier.ToString()); + if (opcode.ok()) { + opcode_val_ = opcode.ValueOrDie(); + return TokKind::kOpcode; + } + + current_ptr_ = token_start_ + 1; + return TokKind::kError; +} + +// Lex names after a % character. +// name ::= [a-zA-Z_][a-zA-Z0-9_.-]* +TokKind HloLexer::LexPercent() { + const char* name_start = current_ptr_; + if (isalpha(static_cast(PeekCurrentChar())) || + PeekCurrentChar() == '_') { + current_ptr_++; + while (IsIdentifierChar(PeekCurrentChar())) { + current_ptr_++; + } + str_val_.assign(name_start, current_ptr_); + return TokKind::kName; + } + return TokKind::kError; +} + +// Lex integer and floating-point values. +// int [-]?[0-9]+ +// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) +// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) +TokKind HloLexer::LexDigitOrNegative() { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 float_pattern = { + R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"}; + if (RE2::Consume(&consumable, *float_pattern)) { + current_ptr_ = consumable.begin(); + tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), + &decimal_val_); + return TokKind::kDecimal; + } + + static LazyRE2 int_pattern = {R"([-]?\d+)"}; + if (RE2::Consume(&consumable, *int_pattern)) { + current_ptr_ = consumable.begin(); + tensorflow::strings::safe_strto64( + StringPieceFromPointers(token_start_, current_ptr_), &int64_val_); + return TokKind::kInt; + } + + return TokKind::kError; +} + +StringPiece HloLexer::GetCurrentLine() const { + const char* start = token_start_; + const char* end = current_ptr_; + if (!CanDereference(start) || !CanDereference(end)) { + return "LINE OUT OF RANGE"; + } + while (start > buf_.begin() && *start != '\n') { + start--; + } + while (end < buf_.end() && *end != '\n') { + end++; + } + return StringPieceFromPointers(start, end); +} + +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h new file mode 100644 index 00000000000..20278fd6cde --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace tools { + +// Lexer for the HloModule::ToString() format text. +class HloLexer { + public: + explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { + current_ptr_ = buf_.begin(); + } + + TokKind Lex() { return current_kind_ = LexToken(); } + TokKind GetKind() const { return current_kind_; } + string GetStrVal() const { + CHECK(GetKind() == TokKind::kName); + return str_val_; + } + Shape GetShapeVal() const { + CHECK(GetKind() == TokKind::kShape); + return shape_val_; + } + HloOpcode GetOpcodeVal() const { + CHECK(GetKind() == TokKind::kOpcode); + return opcode_val_; + } + int64 GetInt64Val() const { + CHECK(GetKind() == TokKind::kInt); + return int64_val_; + } + double GetDecimalVal() const { + CHECK(GetKind() == TokKind::kDecimal); + return decimal_val_; + } + + // Returns the line of text that is currently being lexed. + tensorflow::StringPiece GetCurrentLine() const; + + private: + // Returns the current character. If it's neither the end of input buffer nor + // an invalid character, moves the pointer forward. + int GetNextChar(); + + // Returns the current character. + int PeekCurrentChar() const; + + // Creates StringPiece with the given begin and end. Exits if the begin > end, + // or it's out of the range of the current buffer. + tensorflow::StringPiece StringPieceFromPointers(const char* begin, + const char* end) const; + tensorflow::RegexpStringPiece RegexpStringPieceFromPointers( + const char* begin, const char* end) const; + + // Returns true if the given ptr is dereferenceable within the range of the + // current buffer. + bool CanDereference(const char* ptr) const; + + TokKind LexToken(); + + TokKind LexIdentifier(); + TokKind LexPercent(); + TokKind LexShape(); + TokKind LexConstant(); + TokKind LexDigitOrNegative(); + + const tensorflow::StringPiece buf_; + const char* current_ptr_; + + // Information about the current token. + const char* token_start_; + TokKind current_kind_; + string str_val_; + Shape shape_val_; + HloOpcode opcode_val_; + int64 int64_val_; + double decimal_val_; +}; + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc new file mode 100644 index 00000000000..57700493e6c --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -0,0 +1,502 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace tools { + +namespace { + +using tensorflow::StringPiece; +using tensorflow::strings::StrCat; + +// Parser for the HloModule::ToString() format text. +class HloParser { + public: + explicit HloParser(StringPiece str) : lexer_(str) {} + + // Runs the parser. Returns false if an error occurred. + bool Run(); + + // Returns the parsed HloModule. + std::unique_ptr ConsumeHloModule() { return std::move(module_); } + + // Returns the error information. + string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + + private: + // ParseXXX returns false if an error occurred. + bool ParseHloModule(); + bool ParseComputation(); + bool ParseInstructionList(HloComputation::Builder* builder); + bool ParseInstruction(HloComputation::Builder* builder); + bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); + bool ParseOperands(std::vector* operands, + const int expected_size); + bool ParseParamList(); + bool ParseName(string* result); + bool ParseShape(Shape* result); + bool ParseOpcode(HloOpcode* result); + bool ParseInt64(int64* result); + bool ParseDecimal(double* result); + bool ParseBool(bool* result); + bool ParseToken(TokKind kind, const string& msg); + + // Logs the current parsing line and the given message. Always returns false. + bool TokenError(StringPiece msg); + + // If the current token is 'kind', eats it (i.e. lexes the next token) and + // returns true. + bool EatIfPresent(TokKind kind); + + // Adds the instruction to the pool. Returns false and emits an error if the + // instruction already exists. + bool AddInstruction(const string& name, HloInstruction* instruction); + + // The map from the instruction name to the instruction. This does not own the + // instructions. + std::unordered_map instruction_pool_; + + HloLexer lexer_; + std::unique_ptr module_; + std::vector error_; +}; + +bool HloParser::TokenError(StringPiece msg) { + error_.push_back( + StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; ", msg)); + return false; +} + +bool HloParser::Run() { + lexer_.Lex(); + return ParseHloModule(); +} + +// ::= 'HloModule' name computation +bool HloParser::ParseHloModule() { + if (lexer_.GetKind() != TokKind::kw_HloModule) { + return TokenError("expects HloModule"); + } + // Eat 'HloModule' + lexer_.Lex(); + + string name; + if (!ParseName(&name)) { + return false; + } + + module_ = MakeUnique(name); + + return ParseComputation(); +} + +// computation ::= 'ENTRY' name param_list '->' shape instruction_list +bool HloParser::ParseComputation() { + string name; + if (!ParseToken(TokKind::kw_ENTRY, "expects 'ENTRY'") || !ParseName(&name)) { + return false; + } + auto builder = MakeUnique(name); + + Shape shape; + if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") || + !ParseShape(&shape) || !ParseInstructionList(builder.get())) { + return false; + } + module_->AddEntryComputation(builder->Build()); + return true; +} + +// instruction_list ::= '{' instruction_list1 '}' +// instruction_list1 ::= (instruction)+ +bool HloParser::ParseInstructionList(HloComputation::Builder* builder) { + if (!ParseToken(TokKind::kLbrace, + "expects '{' at the beginning of instruction list.")) { + return false; + } + do { + if (!ParseInstruction(builder)) { + return false; + } + } while (lexer_.GetKind() != TokKind::kRbrace); + return ParseToken(TokKind::kRbrace, + "expects '}' at the end of instruction list."); +} + +// instruction ::= name '=' shape opcode operands +bool HloParser::ParseInstruction(HloComputation::Builder* builder) { + string name; + Shape shape; + HloOpcode opcode; + std::vector operands; + if (!ParseName(&name) || + !ParseToken(TokKind::kEqual, "expects '=' in instruction") || + !ParseShape(&shape) || !ParseOpcode(&opcode)) { + return false; + } + switch (opcode) { + case HloOpcode::kParameter: { + int64 parameter_number; + return ParseToken(TokKind::kLparen, + "expects '(' before parameter number") && + ParseInt64(¶meter_number) && + ParseToken(TokKind::kRparen, + "expects ')' after parameter number") && + AddInstruction( + name, builder->AddInstruction(HloInstruction::CreateParameter( + parameter_number, shape, name))); + } + case HloOpcode::kConstant: { + std::unique_ptr literal; + return ParseToken(TokKind::kLparen, + "expects '(' before parameter number") && + ParseLiteral(&literal, shape) && + ParseToken(TokKind::kRparen, + "expects ')' after parameter number") && + AddInstruction( + name, builder->AddInstruction( + HloInstruction::CreateConstant(std::move(literal)))); + } + // Unary ops. + case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kBitcast: + case HloOpcode::kCeil: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kExp: + case HloOpcode::kIsFinite: + case HloOpcode::kFloor: + case HloOpcode::kLog: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSort: + case HloOpcode::kTanh: { + return ParseOperands(&operands, /*expected_size=*/1) && + AddInstruction(name, + builder->AddInstruction(HloInstruction::CreateUnary( + shape, opcode, operands[0]))); + } + // Binary ops. + case HloOpcode::kAdd: + case HloOpcode::kDivide: + case HloOpcode::kMultiply: + case HloOpcode::kSubtract: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kNe: + case HloOpcode::kDot: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: { + return ParseOperands(&operands, /*expected_size=*/2) && + AddInstruction( + name, builder->AddInstruction(HloInstruction::CreateBinary( + shape, opcode, operands[0], operands[1]))); + } + // Ternary ops. + case HloOpcode::kClamp: + case HloOpcode::kSelect: { + return ParseOperands(&operands, /*expected_size=*/3) && + AddInstruction( + name, + builder->AddInstruction(HloInstruction::CreateTernary( + shape, opcode, operands[0], operands[1], operands[2]))); + } + // Other supported ops. + case HloOpcode::kConvert: { + return ParseOperands(&operands, /*expected_size=*/1) && + AddInstruction( + name, builder->AddInstruction( + HloInstruction::CreateConvert(shape, operands[0]))); + } + case HloOpcode::kCrossReplicaSum: { + return ParseOperands(&operands, /*expected_size=*/1) && + AddInstruction(name, builder->AddInstruction( + HloInstruction::CreateCrossReplicaSum( + shape, operands[0]))); + } + case HloOpcode::kReshape: { + return ParseOperands(&operands, /*expected_size=*/1) && + AddInstruction( + name, builder->AddInstruction( + HloInstruction::CreateReshape(shape, operands[0]))); + } + case HloOpcode::kBroadcast: + case HloOpcode::kCall: + case HloOpcode::kCustomCall: + case HloOpcode::kConcatenate: + case HloOpcode::kReducePrecision: + case HloOpcode::kConvolution: + case HloOpcode::kGetTupleElement: + case HloOpcode::kMap: + case HloOpcode::kPad: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kReverse: + case HloOpcode::kRng: + case HloOpcode::kSlice: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + case HloOpcode::kFusion: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kRecv: + case HloOpcode::kSend: + case HloOpcode::kUpdate: + case HloOpcode::kIndex: + case HloOpcode::kTrace: + return TokenError(StrCat("parsing not yet implemented for op: ", + HloOpcodeString(opcode))); + } +} + +bool HloParser::ParseLiteral(std::unique_ptr* literal, + const Shape& shape) { + switch (shape.element_type()) { + case PRED: + bool b; + if (!ParseBool(&b)) { + return false; + } + *literal = Literal::CreateR0(b); + return true; + case S32: + int64 i; + if (!ParseInt64(&i)) { + return false; + } + *literal = Literal::CreateR0(i); + return true; + case F32: + double d; + if (!ParseDecimal(&d)) { + return false; + } + *literal = Literal::CreateR0(d); + return true; + default: + return TokenError(StrCat("unsupported constant in shape: ", + ShapeUtil::HumanString(shape))); + } +} + +// operands ::= '(' operands1 ')' +// operands1 +// ::= /*empty*/ +// ::= operand (, operand)* +// operand ::= shape name +bool HloParser::ParseOperands(std::vector* operands, + const int expected_size) { + if (!ParseToken(TokKind::kLparen, + "expects '(' at the beginning of operands")) { + return false; + } + if (lexer_.GetKind() == TokKind::kRparen) { + // empty + } else { + do { + Shape shape; + string name; + if (!ParseShape(&shape) || !ParseName(&name)) { + return false; + } + HloInstruction* instruction = + tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + if (!instruction) { + return TokenError(StrCat("instruction does not exist: ", name)); + } + operands->push_back(instruction); + } while (EatIfPresent(TokKind::kComma)); + } + if (expected_size != operands->size()) { + return TokenError(StrCat("expects ", expected_size, " operands, but has ", + operands->size(), " operands")); + } + return ParseToken(TokKind::kRparen, "expects ')' at the end of operands"); +} + +// param_list ::= '(' param_list1 ')' +// param_list1 +// ::= /*empty*/ +// ::= param (',' param)* +// param ::= name shape +bool HloParser::ParseParamList() { + if (!ParseToken(TokKind::kLparen, + "expects '(' at the beginning of param list")) { + return false; + } + + if (lexer_.GetKind() == TokKind::kRparen) { + // empty + } else { + do { + Shape shape; + if (!ParseToken(TokKind::kName, "expects name in parameter") || + !ParseShape(&shape)) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); +} + +// shape ::= shape_val_ +// shape ::= '(' tuple_elements ')' +// tuple_elements +// ::= /*empty*/ +// ::= shape (',' shape)* +bool HloParser::ParseShape(Shape* result) { + if (EatIfPresent(TokKind::kLparen)) { // Tuple + std::vector shapes; + if (lexer_.GetKind() == TokKind::kRparen) { + /*empty*/ + } else { + // shape (',' shape)* + do { + shapes.emplace_back(); + if (!ParseShape(&shapes.back())) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + *result = ShapeUtil::MakeTupleShape(shapes); + return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple."); + } + + if (lexer_.GetKind() != TokKind::kShape) { + return TokenError("expects shape"); + } + *result = lexer_.GetShapeVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseName(string* result) { + VLOG(1) << "ParseName"; + if (lexer_.GetKind() != TokKind::kName) { + return TokenError("expects name"); + } + *result = lexer_.GetStrVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseOpcode(HloOpcode* result) { + VLOG(1) << "ParseOpcode"; + if (lexer_.GetKind() != TokKind::kOpcode) { + return TokenError("expects opcode"); + } + *result = lexer_.GetOpcodeVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseInt64(int64* result) { + VLOG(1) << "ParseInt64"; + if (lexer_.GetKind() != TokKind::kInt) { + return TokenError("expects integer"); + } + *result = lexer_.GetInt64Val(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseDecimal(double* result) { + switch (lexer_.GetKind()) { + case TokKind::kDecimal: + *result = lexer_.GetDecimalVal(); + break; + case TokKind::kInt: + *result = static_cast(lexer_.GetInt64Val()); + break; + default: + return TokenError("expects decimal or integer"); + } + lexer_.Lex(); + return true; +} + +bool HloParser::ParseBool(bool* result) { + if (lexer_.GetKind() != TokKind::kw_true && + lexer_.GetKind() != TokKind::kw_false) { + return TokenError("expects true or false"); + } + *result = lexer_.GetKind() == TokKind::kw_true; + lexer_.Lex(); + return true; +} + +bool HloParser::ParseToken(TokKind kind, const string& msg) { + if (lexer_.GetKind() != kind) { + return TokenError(msg); + } + lexer_.Lex(); + return true; +} + +bool HloParser::EatIfPresent(TokKind kind) { + if (lexer_.GetKind() != kind) { + return false; + } + lexer_.Lex(); + return true; +} + +bool HloParser::AddInstruction(const string& name, + HloInstruction* instruction) { + auto result = instruction_pool_.insert({name, instruction}); + if (!result.second) { + return TokenError(StrCat("instruction already exists: ", name)); + } + return true; +} + +} // namespace + +StatusOr> Parse(StringPiece str) { + HloParser parser(str); + if (!parser.Run()) { + return InvalidArgument("Syntax error: %s", parser.GetError().c_str()); + } + return parser.ConsumeHloModule(); +} + +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/tools/parser/hlo_parser.h new file mode 100644 index 00000000000..9aaf18ef20d --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.h @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace tools { + +// The api of the hlo parser. Given a string in the HloModule::ToString() +// format, returns the parsed HloModule. +StatusOr> Parse(tensorflow::StringPiece str); + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc new file mode 100644 index 00000000000..4ecece3eac1 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -0,0 +1,240 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +#include +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace tools { +namespace { + +struct TestData { + string test_name; + string module_string; +}; + +string TestDataToString(const ::testing::TestParamInfo& data) { + return data.param.test_name; +} + +std::vector CreateTestCases() { + // clang-format off + return std::vector({ +// ax + y +{ +"AxpyParam", +R"(HloModule axpy_module: + +ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[2,4]{1,0} parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} + +)" +}, +// pred constant +{ +"ConstantPred", +R"(HloModule constant_pred_module: + +ENTRY %constant_pred () -> pred[] { + %constant = pred[] constant(true) +} + +)" +}, +// s32 constant +{ +"ConstantS32", +R"(HloModule constant_s32_module: + +ENTRY %constant_s32 () -> s32[] { + %constant = s32[] constant(-42) +} + +)" +}, +// f32 constant, but the value is not a decimal +{ +"ConstantF32", R"(HloModule ConstantF32_module: + +ENTRY %ConstantF32.v4 () -> f32[] { + %constant = f32[] constant(42) +} + +)" +}, +// constant + constant +{ +"AddConstants", +R"(HloModule add_constants_module: + +ENTRY %add_constants () -> f32[] { + %constant = f32[] constant(3.14) + %add = f32[] add(f32[] %constant, f32[] %constant) +} + +)" +}, +// v1 > v2 ? v1 : v2 +{ +"SelectR1F32", +R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module: + +ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { + %v1 = f32[4]{0} parameter(0) + %v2 = f32[4]{0} parameter(1) + %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2) + %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2) +} + +)" +} + }); + // clang-format on +} + +class HloParserTest : public ::testing::Test, + public ::testing::WithParamInterface { + protected: + void ExpectSuccess() { + const string& original = GetParam().module_string; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + EXPECT_EQ(original, result.ValueOrDie()->ToString()); + } +}; + +TEST_P(HloParserTest, Run) { ExpectSuccess(); } + +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); + +TEST_F(HloParserTest, Empty) { + const string original = ""; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, Garbage) { + const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongOpcode) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { + %x = f32[]{} parameter(0) + %y = f32[]{} parameter(1) + %le = pred[]{} le(f32[]{} %x, f32[]{} %y) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongShape) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: g32[]) -> g32[] { + %x = g32[]{} parameter(0) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongOperandsSize) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: f32[]) -> pred[] { + %x = f32[]{} parameter(0) + %eq = pred[]{} equal-to(f32[]{} %x) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, OperandNotFound) { + const string original = R"(HloModule operand_not_found: +ENTRY %blabla (x: f32[]) -> pred[] { + %x = f32[]{} parameter(0) + %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) +} +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, MoreConstants) { + const string original = R"(HloModule SelectScalarS32True_module: + +ENTRY %SelectScalarS32True.v4 () -> s32[] { + %constant.2 = pred[] constant(true) + %constant.1 = s32[] constant(-42) + %constant = s32[] constant(42) + %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) +} + +)"; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + // Constant instructions have no name. The string will be parsed successfully + // but the constant names will not be exactly the same. +} + +TEST_F(HloParserTest, ConstantWithExp) { + const string original = R"(HloModule ConstantWithExp_module: + +ENTRY %ConstantWithExp.v4 () -> f32[] { + %constant.1 = f32[] constant(3e+2) +} + +)"; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + // The string will be parsed successfully but the output strings are not + // exactly the same, because "3e2" is parsed into value 300 and will be + // printed as "300". +} + +TEST_F(HloParserTest, Tuple) { + const string original = R"(HloModule EmptyTupleCreate_module: + +ENTRY %EmptyTupleCreate.v1 () -> () { + %tuple = () tuple() +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +} // namespace +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h new file mode 100644 index 00000000000..1f75e17c7f0 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ + +namespace xla { +namespace tools { + +// Defines different kinds of tokens in a hlo module string. +enum class TokKind { + // Markers + kEof, + kError, + + // Tokens with no info. + kEqual, // = + kComma, // , + kColon, // : + kLsquare, + kRsquare, // [ ] + kLbrace, + kRbrace, // { } + kLparen, + kRparen, // ( ) + + kArrow, // -> + + // Keywords + kw_HloModule, + kw_ENTRY, + kw_true, + kw_false, + + // Typed tokens. + kName, // %foo + kShape, // f32[2,3]{1,0} + kOpcode, // add + kInt, // 42 + kDecimal, // 4.2 +}; + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ From 47e92cfd08a230034268a1eeca625fd1e9908616 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 19 Oct 2017 14:23:03 -0700 Subject: [PATCH 02/41] `tf.py_func`: Handle NumPy arrays of np.object that hold unicode strings. This also fixes a bug affecting `tf.data.Dataset.from_generator()` on Python 3, where the generator yields Unicode (i.e. default) strings. PiperOrigin-RevId: 172798007 --- tensorflow/BUILD | 1 - .../compiler/xla/service/hlo_computation.cc | 4 +- tensorflow/compiler/xla/shape_util.cc | 45 +- tensorflow/compiler/xla/tools/parser/BUILD | 84 --- .../compiler/xla/tools/parser/README.md | 69 --- .../compiler/xla/tools/parser/hlo_lexer.cc | 270 ---------- .../compiler/xla/tools/parser/hlo_lexer.h | 108 ---- .../compiler/xla/tools/parser/hlo_parser.cc | 502 ------------------ .../compiler/xla/tools/parser/hlo_parser.h | 37 -- .../xla/tools/parser/hlo_parser_test.cc | 240 --------- .../compiler/xla/tools/parser/hlo_token.h | 58 -- 11 files changed, 16 insertions(+), 1402 deletions(-) delete mode 100644 tensorflow/compiler/xla/tools/parser/BUILD delete mode 100644 tensorflow/compiler/xla/tools/parser/README.md delete mode 100644 tensorflow/compiler/xla/tools/parser/hlo_lexer.cc delete mode 100644 tensorflow/compiler/xla/tools/parser/hlo_lexer.h delete mode 100644 tensorflow/compiler/xla/tools/parser/hlo_parser.cc delete mode 100644 tensorflow/compiler/xla/tools/parser/hlo_parser.h delete mode 100644 tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc delete mode 100644 tensorflow/compiler/xla/tools/parser/hlo_token.h diff --git a/tensorflow/BUILD b/tensorflow/BUILD index d5c56cdc184..e351037abbd 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -340,7 +340,6 @@ filegroup( "//tensorflow/compiler/xla/service/llvm_ir:all_files", "//tensorflow/compiler/xla/tests:all_files", "//tensorflow/compiler/xla/tools:all_files", - "//tensorflow/compiler/xla/tools/parser:all_files", "//tensorflow/contrib:all_files", "//tensorflow/contrib/all_reduce:all_files", "//tensorflow/contrib/android:all_files", diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 51ead753f04..9b3104eaacd 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -373,8 +373,8 @@ string HloComputation::ToString(int nested_level) const { for (int i = 0; i < nested_level; i++) { s << " "; } - s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) - << " {\n"; + s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) + << " { \n"; for (const HloInstruction* instruction : MakeInstructionPostOrder()) { for (int i = 0; i < nested_level; i++) { s << " "; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index af583bed625..8e16056b239 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -102,32 +102,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { return true; } -// Constructs and returns the new shape with the given minor_to_major order in -// its Layout. -StatusOr MakeShapeWithLayoutInternal( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { - if (dimensions.size() != minor_to_major.size()) { - return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", - dimensions.size(), minor_to_major.size()); - } - if (element_type == OPAQUE || element_type == TUPLE) { - return InvalidArgument("Unsupported element type: %s", - PrimitiveType_Name(element_type).c_str()); - } - Shape shape = ShapeUtil::MakeShape(element_type, dimensions); - auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->Clear(); - for (int64 value : minor_to_major) { - min2maj->Add(value); - } - if (!shape.has_layout()) { - return InvalidArgument("Shape has no layout."); - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); - return shape; -} - } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { @@ -178,8 +152,16 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShapeWithLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice minor_to_major) { - return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) - .ValueOrDie(); + CHECK_EQ(dimensions.size(), minor_to_major.size()); + Shape shape = MakeShape(element_type, dimensions); + auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); + min2maj->Clear(); + for (int64 value : minor_to_major) { + min2maj->Add(value); + } + DCHECK(shape.has_layout()); + TF_DCHECK_OK(ValidateShape(shape)); + return shape; } /* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( @@ -517,10 +499,11 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // Extract the layout minor-to-major and set it. TF_ASSIGN_OR_RETURN(std::vector min2maj, comma_list_to_int64s(layout_string)); - TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( - primitive_type, dimensions, min2maj)); + TF_RET_CHECK(dimensions.size() == min2maj.size()); + result = + ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj); } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); + TF_DCHECK_OK(ShapeUtil::ValidateShape(result)); return std::move(result); } diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD deleted file mode 100644 index c84ca9fc833..00000000000 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ /dev/null @@ -1,84 +0,0 @@ -# Build file for the Hlo parser. - -licenses(["notice"]) # Apache 2.0 - -package( - default_visibility = [":friends"], -) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/xla:friends", - ], -) - -# Filegroup used to collect source files for dependency checking. -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "hlo_lexer", - srcs = ["hlo_lexer.cc"], - hdrs = [ - "hlo_lexer.h", - "hlo_token.h", - ], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - ], -) - -cc_library( - name = "hlo_parser", - srcs = ["hlo_parser.cc"], - hdrs = ["hlo_parser.h"], - deps = [ - ":hlo_lexer", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - -tf_cc_test( - name = "hlo_parser_test", - size = "small", - srcs = ["hlo_parser_test.cc"], - deps = [ - ":hlo_parser", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md deleted file mode 100644 index a334bc2b297..00000000000 --- a/tensorflow/compiler/xla/tools/parser/README.md +++ /dev/null @@ -1,69 +0,0 @@ -# HloModule string syntax - -TODO: Support subcomputations (for fusion, reduce, while, ...). - -TODO: Support ops that require extra attributes, e.g. dimensions, strides. - -```yacc -hlo_module - : 'HloModule' name computation - ; - -computation - : 'ENTRY' name param_list '->' shape instruction_list - ; - -instruction_list - : '{' instruction_list1 '}' - ; -instruction_list1 - : instruction - | instruction_list1 instruction - ; -instruction - : name '=' shape opcode operands - ; - -operands - : '(' operands1 ')' - ; -operands1 - : /*empty*/ - | operand - | operands1 ',' operand - ; -operand - : shape name - ; - -param_list - : '(' param_list1 ')' - ; -param_list1 - : /*empty*/ - | param - | param_list1 ',' param - ; -param - : name shape - ; - -shape - : shape_val_ - | '(' tuple_elements ')' - ; -tuple_elements - : /*empty*/ - | shape (',' shape)* - ; - -name - : identifier ':' - | '%' identifier - ; - -identifier - : [a-zA-Z_][a-zA-Z0-9_.-]* - ; - -``` diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc deleted file mode 100644 index 3e84ffcbd2c..00000000000 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ /dev/null @@ -1,270 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" - -#include - -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/optional.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/platform/regexp.h" - -namespace xla { -namespace tools { - -using tensorflow::StringPiece; - -namespace { - -constexpr int kEOF = -1; -constexpr int kError = -2; - -// [a-zA-Z0-9_.-] -bool IsIdentifierChar(char c) { - return isalnum(static_cast(c)) || c == '-' || c == '.' || - c == '_'; -} - -} // namespace - -int HloLexer::GetNextChar() { - int current_char = PeekCurrentChar(); - if (current_char != kEOF && current_char != kError) { - current_ptr_++; - } - return current_char; -} - -int HloLexer::PeekCurrentChar() const { - if (current_ptr_ == buf_.end()) { - return kEOF; - } - char current_char = *current_ptr_; - if (current_char == 0) { - // '\0' should not appear in the middle of the string. - return kError; - } - return static_cast(current_char); -} - -bool HloLexer::CanDereference(const char* ptr) const { - return ptr < buf_.end() && ptr >= buf_.begin(); -} - -StringPiece HloLexer::StringPieceFromPointers(const char* begin, - const char* end) const { - CHECK(begin <= end); - CHECK(begin == buf_.end() || CanDereference(begin)); - CHECK(end == buf_.end() || CanDereference(end)); - return StringPiece(begin, end - begin); -} - -tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( - const char* begin, const char* end) const { - CHECK(begin <= end); - CHECK(begin == buf_.end() || CanDereference(begin)); - CHECK(end == buf_.end() || CanDereference(end)); - return tensorflow::RegexpStringPiece(begin, end - begin); -} - -TokKind HloLexer::LexToken() { - while (true) { - token_start_ = current_ptr_; - - int current_char = GetNextChar(); - switch (current_char) { - default: - // [a-zA-Z_] - if (isalpha(static_cast(current_char)) || - current_char == '_') { - return LexIdentifier(); - } - return TokKind::kError; - case kEOF: - // Hit the end of the input buffer. - return TokKind::kEof; - case kError: - // Hit an invalid character in the input buffer. - return TokKind::kError; - case ' ': - case '\t': - case '\n': - case '\r': - // Ignore whitespace. - continue; - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - case '-': - if (current_char == '-' && PeekCurrentChar() == '>') { - current_ptr_++; - return TokKind::kArrow; - } - return LexDigitOrNegative(); - case '=': - return TokKind::kEqual; - case ',': - return TokKind::kComma; - case '%': - return LexPercent(); - case ':': - return TokKind::kColon; - case '[': - return TokKind::kLsquare; - case ']': - return TokKind::kRsquare; - case '{': - return TokKind::kLbrace; - case '}': - return TokKind::kRbrace; - case '(': - return TokKind::kLparen; - case ')': - return TokKind::kRparen; - } - } -} - -// Lex a shape, name, keyword, or opcode. -// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? -// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: -// keyword ::= HloModule, ENTRY, ... -// opcode ::= add, greater-than, ... -TokKind HloLexer::LexIdentifier() { - { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); - // 'consumable' will be advanced iff its prefix matches the pattern. - static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,]*)\](?:\s*{([\d,]*)})?)"}; - if (RE2::Consume(&consumable, *shape_pattern)) { - auto status_or_shape = ShapeUtil::ParseShapeString( - StringPieceFromPointers(token_start_, consumable.begin())); - if (status_or_shape.ok()) { - // This is a shape string. - shape_val_ = status_or_shape.ValueOrDie(); - current_ptr_ = consumable.begin(); - return TokKind::kShape; - } - } - } - - while (IsIdentifierChar(PeekCurrentChar())) { - current_ptr_++; - } - - // If followed by ':', it's a name. - if (PeekCurrentChar() == ':') { - str_val_.assign(token_start_, current_ptr_); - current_ptr_++; // skip ':' - return TokKind::kName; - } - - StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_); - - // See if this is a keyword. -#define KEYWORD(STR) \ - do { \ - if (identifier == #STR) { \ - return TokKind::kw_##STR; \ - } \ - } while (false) - - KEYWORD(true); - KEYWORD(false); - KEYWORD(HloModule); - KEYWORD(ENTRY); - -#undef KEYWORD - - // See if this is an opcode. - auto opcode = StringToHloOpcode(identifier.ToString()); - if (opcode.ok()) { - opcode_val_ = opcode.ValueOrDie(); - return TokKind::kOpcode; - } - - current_ptr_ = token_start_ + 1; - return TokKind::kError; -} - -// Lex names after a % character. -// name ::= [a-zA-Z_][a-zA-Z0-9_.-]* -TokKind HloLexer::LexPercent() { - const char* name_start = current_ptr_; - if (isalpha(static_cast(PeekCurrentChar())) || - PeekCurrentChar() == '_') { - current_ptr_++; - while (IsIdentifierChar(PeekCurrentChar())) { - current_ptr_++; - } - str_val_.assign(name_start, current_ptr_); - return TokKind::kName; - } - return TokKind::kError; -} - -// Lex integer and floating-point values. -// int [-]?[0-9]+ -// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) -// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) -TokKind HloLexer::LexDigitOrNegative() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); - static LazyRE2 float_pattern = { - R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"}; - if (RE2::Consume(&consumable, *float_pattern)) { - current_ptr_ = consumable.begin(); - tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), - &decimal_val_); - return TokKind::kDecimal; - } - - static LazyRE2 int_pattern = {R"([-]?\d+)"}; - if (RE2::Consume(&consumable, *int_pattern)) { - current_ptr_ = consumable.begin(); - tensorflow::strings::safe_strto64( - StringPieceFromPointers(token_start_, current_ptr_), &int64_val_); - return TokKind::kInt; - } - - return TokKind::kError; -} - -StringPiece HloLexer::GetCurrentLine() const { - const char* start = token_start_; - const char* end = current_ptr_; - if (!CanDereference(start) || !CanDereference(end)) { - return "LINE OUT OF RANGE"; - } - while (start > buf_.begin() && *start != '\n') { - start--; - } - while (end < buf_.end() && *end != '\n') { - end++; - } - return StringPieceFromPointers(start, end); -} - -} // namespace tools -} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h deleted file mode 100644 index 20278fd6cde..00000000000 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ - -#include - -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_token.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/regexp.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { -namespace tools { - -// Lexer for the HloModule::ToString() format text. -class HloLexer { - public: - explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { - current_ptr_ = buf_.begin(); - } - - TokKind Lex() { return current_kind_ = LexToken(); } - TokKind GetKind() const { return current_kind_; } - string GetStrVal() const { - CHECK(GetKind() == TokKind::kName); - return str_val_; - } - Shape GetShapeVal() const { - CHECK(GetKind() == TokKind::kShape); - return shape_val_; - } - HloOpcode GetOpcodeVal() const { - CHECK(GetKind() == TokKind::kOpcode); - return opcode_val_; - } - int64 GetInt64Val() const { - CHECK(GetKind() == TokKind::kInt); - return int64_val_; - } - double GetDecimalVal() const { - CHECK(GetKind() == TokKind::kDecimal); - return decimal_val_; - } - - // Returns the line of text that is currently being lexed. - tensorflow::StringPiece GetCurrentLine() const; - - private: - // Returns the current character. If it's neither the end of input buffer nor - // an invalid character, moves the pointer forward. - int GetNextChar(); - - // Returns the current character. - int PeekCurrentChar() const; - - // Creates StringPiece with the given begin and end. Exits if the begin > end, - // or it's out of the range of the current buffer. - tensorflow::StringPiece StringPieceFromPointers(const char* begin, - const char* end) const; - tensorflow::RegexpStringPiece RegexpStringPieceFromPointers( - const char* begin, const char* end) const; - - // Returns true if the given ptr is dereferenceable within the range of the - // current buffer. - bool CanDereference(const char* ptr) const; - - TokKind LexToken(); - - TokKind LexIdentifier(); - TokKind LexPercent(); - TokKind LexShape(); - TokKind LexConstant(); - TokKind LexDigitOrNegative(); - - const tensorflow::StringPiece buf_; - const char* current_ptr_; - - // Information about the current token. - const char* token_start_; - TokKind current_kind_; - string str_val_; - Shape shape_val_; - HloOpcode opcode_val_; - int64 int64_val_; - double decimal_val_; -}; - -} // namespace tools -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc deleted file mode 100644 index 57700493e6c..00000000000 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ /dev/null @@ -1,502 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" - -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/strcat.h" - -namespace xla { -namespace tools { - -namespace { - -using tensorflow::StringPiece; -using tensorflow::strings::StrCat; - -// Parser for the HloModule::ToString() format text. -class HloParser { - public: - explicit HloParser(StringPiece str) : lexer_(str) {} - - // Runs the parser. Returns false if an error occurred. - bool Run(); - - // Returns the parsed HloModule. - std::unique_ptr ConsumeHloModule() { return std::move(module_); } - - // Returns the error information. - string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } - - private: - // ParseXXX returns false if an error occurred. - bool ParseHloModule(); - bool ParseComputation(); - bool ParseInstructionList(HloComputation::Builder* builder); - bool ParseInstruction(HloComputation::Builder* builder); - bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseOperands(std::vector* operands, - const int expected_size); - bool ParseParamList(); - bool ParseName(string* result); - bool ParseShape(Shape* result); - bool ParseOpcode(HloOpcode* result); - bool ParseInt64(int64* result); - bool ParseDecimal(double* result); - bool ParseBool(bool* result); - bool ParseToken(TokKind kind, const string& msg); - - // Logs the current parsing line and the given message. Always returns false. - bool TokenError(StringPiece msg); - - // If the current token is 'kind', eats it (i.e. lexes the next token) and - // returns true. - bool EatIfPresent(TokKind kind); - - // Adds the instruction to the pool. Returns false and emits an error if the - // instruction already exists. - bool AddInstruction(const string& name, HloInstruction* instruction); - - // The map from the instruction name to the instruction. This does not own the - // instructions. - std::unordered_map instruction_pool_; - - HloLexer lexer_; - std::unique_ptr module_; - std::vector error_; -}; - -bool HloParser::TokenError(StringPiece msg) { - error_.push_back( - StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; ", msg)); - return false; -} - -bool HloParser::Run() { - lexer_.Lex(); - return ParseHloModule(); -} - -// ::= 'HloModule' name computation -bool HloParser::ParseHloModule() { - if (lexer_.GetKind() != TokKind::kw_HloModule) { - return TokenError("expects HloModule"); - } - // Eat 'HloModule' - lexer_.Lex(); - - string name; - if (!ParseName(&name)) { - return false; - } - - module_ = MakeUnique(name); - - return ParseComputation(); -} - -// computation ::= 'ENTRY' name param_list '->' shape instruction_list -bool HloParser::ParseComputation() { - string name; - if (!ParseToken(TokKind::kw_ENTRY, "expects 'ENTRY'") || !ParseName(&name)) { - return false; - } - auto builder = MakeUnique(name); - - Shape shape; - if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") || - !ParseShape(&shape) || !ParseInstructionList(builder.get())) { - return false; - } - module_->AddEntryComputation(builder->Build()); - return true; -} - -// instruction_list ::= '{' instruction_list1 '}' -// instruction_list1 ::= (instruction)+ -bool HloParser::ParseInstructionList(HloComputation::Builder* builder) { - if (!ParseToken(TokKind::kLbrace, - "expects '{' at the beginning of instruction list.")) { - return false; - } - do { - if (!ParseInstruction(builder)) { - return false; - } - } while (lexer_.GetKind() != TokKind::kRbrace); - return ParseToken(TokKind::kRbrace, - "expects '}' at the end of instruction list."); -} - -// instruction ::= name '=' shape opcode operands -bool HloParser::ParseInstruction(HloComputation::Builder* builder) { - string name; - Shape shape; - HloOpcode opcode; - std::vector operands; - if (!ParseName(&name) || - !ParseToken(TokKind::kEqual, "expects '=' in instruction") || - !ParseShape(&shape) || !ParseOpcode(&opcode)) { - return false; - } - switch (opcode) { - case HloOpcode::kParameter: { - int64 parameter_number; - return ParseToken(TokKind::kLparen, - "expects '(' before parameter number") && - ParseInt64(¶meter_number) && - ParseToken(TokKind::kRparen, - "expects ')' after parameter number") && - AddInstruction( - name, builder->AddInstruction(HloInstruction::CreateParameter( - parameter_number, shape, name))); - } - case HloOpcode::kConstant: { - std::unique_ptr literal; - return ParseToken(TokKind::kLparen, - "expects '(' before parameter number") && - ParseLiteral(&literal, shape) && - ParseToken(TokKind::kRparen, - "expects ')' after parameter number") && - AddInstruction( - name, builder->AddInstruction( - HloInstruction::CreateConstant(std::move(literal)))); - } - // Unary ops. - case HloOpcode::kAbs: - case HloOpcode::kRoundNearestAfz: - case HloOpcode::kBitcast: - case HloOpcode::kCeil: - case HloOpcode::kCopy: - case HloOpcode::kCos: - case HloOpcode::kExp: - case HloOpcode::kIsFinite: - case HloOpcode::kFloor: - case HloOpcode::kLog: - case HloOpcode::kNot: - case HloOpcode::kNegate: - case HloOpcode::kSign: - case HloOpcode::kSin: - case HloOpcode::kSort: - case HloOpcode::kTanh: { - return ParseOperands(&operands, /*expected_size=*/1) && - AddInstruction(name, - builder->AddInstruction(HloInstruction::CreateUnary( - shape, opcode, operands[0]))); - } - // Binary ops. - case HloOpcode::kAdd: - case HloOpcode::kDivide: - case HloOpcode::kMultiply: - case HloOpcode::kSubtract: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: - case HloOpcode::kDot: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kAnd: - case HloOpcode::kOr: - case HloOpcode::kShiftLeft: - case HloOpcode::kShiftRightArithmetic: - case HloOpcode::kShiftRightLogical: { - return ParseOperands(&operands, /*expected_size=*/2) && - AddInstruction( - name, builder->AddInstruction(HloInstruction::CreateBinary( - shape, opcode, operands[0], operands[1]))); - } - // Ternary ops. - case HloOpcode::kClamp: - case HloOpcode::kSelect: { - return ParseOperands(&operands, /*expected_size=*/3) && - AddInstruction( - name, - builder->AddInstruction(HloInstruction::CreateTernary( - shape, opcode, operands[0], operands[1], operands[2]))); - } - // Other supported ops. - case HloOpcode::kConvert: { - return ParseOperands(&operands, /*expected_size=*/1) && - AddInstruction( - name, builder->AddInstruction( - HloInstruction::CreateConvert(shape, operands[0]))); - } - case HloOpcode::kCrossReplicaSum: { - return ParseOperands(&operands, /*expected_size=*/1) && - AddInstruction(name, builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum( - shape, operands[0]))); - } - case HloOpcode::kReshape: { - return ParseOperands(&operands, /*expected_size=*/1) && - AddInstruction( - name, builder->AddInstruction( - HloInstruction::CreateReshape(shape, operands[0]))); - } - case HloOpcode::kBroadcast: - case HloOpcode::kCall: - case HloOpcode::kCustomCall: - case HloOpcode::kConcatenate: - case HloOpcode::kReducePrecision: - case HloOpcode::kConvolution: - case HloOpcode::kGetTupleElement: - case HloOpcode::kMap: - case HloOpcode::kPad: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - case HloOpcode::kReverse: - case HloOpcode::kRng: - case HloOpcode::kSlice: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kTranspose: - case HloOpcode::kTuple: - case HloOpcode::kWhile: - case HloOpcode::kFusion: - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kBatchNormGrad: - case HloOpcode::kRecv: - case HloOpcode::kSend: - case HloOpcode::kUpdate: - case HloOpcode::kIndex: - case HloOpcode::kTrace: - return TokenError(StrCat("parsing not yet implemented for op: ", - HloOpcodeString(opcode))); - } -} - -bool HloParser::ParseLiteral(std::unique_ptr* literal, - const Shape& shape) { - switch (shape.element_type()) { - case PRED: - bool b; - if (!ParseBool(&b)) { - return false; - } - *literal = Literal::CreateR0(b); - return true; - case S32: - int64 i; - if (!ParseInt64(&i)) { - return false; - } - *literal = Literal::CreateR0(i); - return true; - case F32: - double d; - if (!ParseDecimal(&d)) { - return false; - } - *literal = Literal::CreateR0(d); - return true; - default: - return TokenError(StrCat("unsupported constant in shape: ", - ShapeUtil::HumanString(shape))); - } -} - -// operands ::= '(' operands1 ')' -// operands1 -// ::= /*empty*/ -// ::= operand (, operand)* -// operand ::= shape name -bool HloParser::ParseOperands(std::vector* operands, - const int expected_size) { - if (!ParseToken(TokKind::kLparen, - "expects '(' at the beginning of operands")) { - return false; - } - if (lexer_.GetKind() == TokKind::kRparen) { - // empty - } else { - do { - Shape shape; - string name; - if (!ParseShape(&shape) || !ParseName(&name)) { - return false; - } - HloInstruction* instruction = - tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); - if (!instruction) { - return TokenError(StrCat("instruction does not exist: ", name)); - } - operands->push_back(instruction); - } while (EatIfPresent(TokKind::kComma)); - } - if (expected_size != operands->size()) { - return TokenError(StrCat("expects ", expected_size, " operands, but has ", - operands->size(), " operands")); - } - return ParseToken(TokKind::kRparen, "expects ')' at the end of operands"); -} - -// param_list ::= '(' param_list1 ')' -// param_list1 -// ::= /*empty*/ -// ::= param (',' param)* -// param ::= name shape -bool HloParser::ParseParamList() { - if (!ParseToken(TokKind::kLparen, - "expects '(' at the beginning of param list")) { - return false; - } - - if (lexer_.GetKind() == TokKind::kRparen) { - // empty - } else { - do { - Shape shape; - if (!ParseToken(TokKind::kName, "expects name in parameter") || - !ParseShape(&shape)) { - return false; - } - } while (EatIfPresent(TokKind::kComma)); - } - return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); -} - -// shape ::= shape_val_ -// shape ::= '(' tuple_elements ')' -// tuple_elements -// ::= /*empty*/ -// ::= shape (',' shape)* -bool HloParser::ParseShape(Shape* result) { - if (EatIfPresent(TokKind::kLparen)) { // Tuple - std::vector shapes; - if (lexer_.GetKind() == TokKind::kRparen) { - /*empty*/ - } else { - // shape (',' shape)* - do { - shapes.emplace_back(); - if (!ParseShape(&shapes.back())) { - return false; - } - } while (EatIfPresent(TokKind::kComma)); - } - *result = ShapeUtil::MakeTupleShape(shapes); - return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple."); - } - - if (lexer_.GetKind() != TokKind::kShape) { - return TokenError("expects shape"); - } - *result = lexer_.GetShapeVal(); - lexer_.Lex(); - return true; -} - -bool HloParser::ParseName(string* result) { - VLOG(1) << "ParseName"; - if (lexer_.GetKind() != TokKind::kName) { - return TokenError("expects name"); - } - *result = lexer_.GetStrVal(); - lexer_.Lex(); - return true; -} - -bool HloParser::ParseOpcode(HloOpcode* result) { - VLOG(1) << "ParseOpcode"; - if (lexer_.GetKind() != TokKind::kOpcode) { - return TokenError("expects opcode"); - } - *result = lexer_.GetOpcodeVal(); - lexer_.Lex(); - return true; -} - -bool HloParser::ParseInt64(int64* result) { - VLOG(1) << "ParseInt64"; - if (lexer_.GetKind() != TokKind::kInt) { - return TokenError("expects integer"); - } - *result = lexer_.GetInt64Val(); - lexer_.Lex(); - return true; -} - -bool HloParser::ParseDecimal(double* result) { - switch (lexer_.GetKind()) { - case TokKind::kDecimal: - *result = lexer_.GetDecimalVal(); - break; - case TokKind::kInt: - *result = static_cast(lexer_.GetInt64Val()); - break; - default: - return TokenError("expects decimal or integer"); - } - lexer_.Lex(); - return true; -} - -bool HloParser::ParseBool(bool* result) { - if (lexer_.GetKind() != TokKind::kw_true && - lexer_.GetKind() != TokKind::kw_false) { - return TokenError("expects true or false"); - } - *result = lexer_.GetKind() == TokKind::kw_true; - lexer_.Lex(); - return true; -} - -bool HloParser::ParseToken(TokKind kind, const string& msg) { - if (lexer_.GetKind() != kind) { - return TokenError(msg); - } - lexer_.Lex(); - return true; -} - -bool HloParser::EatIfPresent(TokKind kind) { - if (lexer_.GetKind() != kind) { - return false; - } - lexer_.Lex(); - return true; -} - -bool HloParser::AddInstruction(const string& name, - HloInstruction* instruction) { - auto result = instruction_pool_.insert({name, instruction}); - if (!result.second) { - return TokenError(StrCat("instruction already exists: ", name)); - } - return true; -} - -} // namespace - -StatusOr> Parse(StringPiece str) { - HloParser parser(str); - if (!parser.Run()) { - return InvalidArgument("Syntax error: %s", parser.GetError().c_str()); - } - return parser.ConsumeHloModule(); -} - -} // namespace tools -} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/tools/parser/hlo_parser.h deleted file mode 100644 index 9aaf18ef20d..00000000000 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { -namespace tools { - -// The api of the hlo parser. Given a string in the HloModule::ToString() -// format, returns the parsed HloModule. -StatusOr> Parse(tensorflow::StringPiece str); - -} // namespace tools -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc deleted file mode 100644 index 4ecece3eac1..00000000000 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ /dev/null @@ -1,240 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" - -#include -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace tools { -namespace { - -struct TestData { - string test_name; - string module_string; -}; - -string TestDataToString(const ::testing::TestParamInfo& data) { - return data.param.test_name; -} - -std::vector CreateTestCases() { - // clang-format off - return std::vector({ -// ax + y -{ -"AxpyParam", -R"(HloModule axpy_module: - -ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { - %alpha = f32[2,4]{1,0} parameter(0) - %x = f32[2,4]{1,0} parameter(1) - %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x) - %y = f32[2,4]{1,0} parameter(2) - %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) -} - -)" -}, -// pred constant -{ -"ConstantPred", -R"(HloModule constant_pred_module: - -ENTRY %constant_pred () -> pred[] { - %constant = pred[] constant(true) -} - -)" -}, -// s32 constant -{ -"ConstantS32", -R"(HloModule constant_s32_module: - -ENTRY %constant_s32 () -> s32[] { - %constant = s32[] constant(-42) -} - -)" -}, -// f32 constant, but the value is not a decimal -{ -"ConstantF32", R"(HloModule ConstantF32_module: - -ENTRY %ConstantF32.v4 () -> f32[] { - %constant = f32[] constant(42) -} - -)" -}, -// constant + constant -{ -"AddConstants", -R"(HloModule add_constants_module: - -ENTRY %add_constants () -> f32[] { - %constant = f32[] constant(3.14) - %add = f32[] add(f32[] %constant, f32[] %constant) -} - -)" -}, -// v1 > v2 ? v1 : v2 -{ -"SelectR1F32", -R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module: - -ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { - %v1 = f32[4]{0} parameter(0) - %v2 = f32[4]{0} parameter(1) - %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2) - %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2) -} - -)" -} - }); - // clang-format on -} - -class HloParserTest : public ::testing::Test, - public ::testing::WithParamInterface { - protected: - void ExpectSuccess() { - const string& original = GetParam().module_string; - auto result = Parse(original); - TF_EXPECT_OK(result.status()); - EXPECT_EQ(original, result.ValueOrDie()->ToString()); - } -}; - -TEST_P(HloParserTest, Run) { ExpectSuccess(); } - -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, - ::testing::ValuesIn(CreateTestCases()), - TestDataToString); - -TEST_F(HloParserTest, Empty) { - const string original = ""; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); -} - -TEST_F(HloParserTest, Garbage) { - const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); -} - -TEST_F(HloParserTest, WrongOpcode) { - const string original = R"(HloModule wrong_opcode: - -ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { - %x = f32[]{} parameter(0) - %y = f32[]{} parameter(1) - %le = pred[]{} le(f32[]{} %x, f32[]{} %y) -} - -)"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); -} - -TEST_F(HloParserTest, WrongShape) { - const string original = R"(HloModule wrong_opcode: - -ENTRY %blabla (x: g32[]) -> g32[] { - %x = g32[]{} parameter(0) -} - -)"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); -} - -TEST_F(HloParserTest, WrongOperandsSize) { - const string original = R"(HloModule wrong_opcode: - -ENTRY %blabla (x: f32[]) -> pred[] { - %x = f32[]{} parameter(0) - %eq = pred[]{} equal-to(f32[]{} %x) -} - -)"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); -} - -TEST_F(HloParserTest, OperandNotFound) { - const string original = R"(HloModule operand_not_found: -ENTRY %blabla (x: f32[]) -> pred[] { - %x = f32[]{} parameter(0) - %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) -} -)"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); -} - -TEST_F(HloParserTest, MoreConstants) { - const string original = R"(HloModule SelectScalarS32True_module: - -ENTRY %SelectScalarS32True.v4 () -> s32[] { - %constant.2 = pred[] constant(true) - %constant.1 = s32[] constant(-42) - %constant = s32[] constant(42) - %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) -} - -)"; - auto result = Parse(original); - TF_EXPECT_OK(result.status()); - // Constant instructions have no name. The string will be parsed successfully - // but the constant names will not be exactly the same. -} - -TEST_F(HloParserTest, ConstantWithExp) { - const string original = R"(HloModule ConstantWithExp_module: - -ENTRY %ConstantWithExp.v4 () -> f32[] { - %constant.1 = f32[] constant(3e+2) -} - -)"; - auto result = Parse(original); - TF_EXPECT_OK(result.status()); - // The string will be parsed successfully but the output strings are not - // exactly the same, because "3e2" is parsed into value 300 and will be - // printed as "300". -} - -TEST_F(HloParserTest, Tuple) { - const string original = R"(HloModule EmptyTupleCreate_module: - -ENTRY %EmptyTupleCreate.v1 () -> () { - %tuple = () tuple() -} - -)"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); -} - -} // namespace -} // namespace tools -} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h deleted file mode 100644 index 1f75e17c7f0..00000000000 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ - -namespace xla { -namespace tools { - -// Defines different kinds of tokens in a hlo module string. -enum class TokKind { - // Markers - kEof, - kError, - - // Tokens with no info. - kEqual, // = - kComma, // , - kColon, // : - kLsquare, - kRsquare, // [ ] - kLbrace, - kRbrace, // { } - kLparen, - kRparen, // ( ) - - kArrow, // -> - - // Keywords - kw_HloModule, - kw_ENTRY, - kw_true, - kw_false, - - // Typed tokens. - kName, // %foo - kShape, // f32[2,3]{1,0} - kOpcode, // add - kInt, // 42 - kDecimal, // 4.2 -}; - -} // namespace tools -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ From 6c074971ab80362954bea07ff2896cb91636b787 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Oct 2017 15:02:01 -0700 Subject: [PATCH 03/41] Add a recursive descent parser for the HloModule string. It constructs an HloModule object from a string printed by HloModule::ToString(). This is a initial stage. It currently supports: - unary, binary, ternary ops, and other ops that don't have extra attributes. - module with entry computation only. - simple cases for constant instruction. To make the parser simpler, this cl removes a whitespace and adds a '%' before the computation name in HloComputation::ToString(). Further steps will enable parsing subcomputations, more cases of constants, tuple, and ops that require extra attributes (e.g., broadcast dimensions, subcomputation). PiperOrigin-RevId: 172804214 --- tensorflow/BUILD | 1 + .../compiler/xla/service/hlo_computation.cc | 4 +- tensorflow/compiler/xla/shape_util.cc | 45 +- tensorflow/compiler/xla/tools/parser/BUILD | 84 +++ .../compiler/xla/tools/parser/README.md | 69 +++ .../compiler/xla/tools/parser/hlo_lexer.cc | 270 ++++++++++ .../compiler/xla/tools/parser/hlo_lexer.h | 108 ++++ .../compiler/xla/tools/parser/hlo_parser.cc | 502 ++++++++++++++++++ .../compiler/xla/tools/parser/hlo_parser.h | 37 ++ .../xla/tools/parser/hlo_parser_test.cc | 240 +++++++++ .../compiler/xla/tools/parser/hlo_token.h | 58 ++ 11 files changed, 1402 insertions(+), 16 deletions(-) create mode 100644 tensorflow/compiler/xla/tools/parser/BUILD create mode 100644 tensorflow/compiler/xla/tools/parser/README.md create mode 100644 tensorflow/compiler/xla/tools/parser/hlo_lexer.cc create mode 100644 tensorflow/compiler/xla/tools/parser/hlo_lexer.h create mode 100644 tensorflow/compiler/xla/tools/parser/hlo_parser.cc create mode 100644 tensorflow/compiler/xla/tools/parser/hlo_parser.h create mode 100644 tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc create mode 100644 tensorflow/compiler/xla/tools/parser/hlo_token.h diff --git a/tensorflow/BUILD b/tensorflow/BUILD index e351037abbd..d5c56cdc184 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -340,6 +340,7 @@ filegroup( "//tensorflow/compiler/xla/service/llvm_ir:all_files", "//tensorflow/compiler/xla/tests:all_files", "//tensorflow/compiler/xla/tools:all_files", + "//tensorflow/compiler/xla/tools/parser:all_files", "//tensorflow/contrib:all_files", "//tensorflow/contrib/all_reduce:all_files", "//tensorflow/contrib/android:all_files", diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 9b3104eaacd..51ead753f04 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -373,8 +373,8 @@ string HloComputation::ToString(int nested_level) const { for (int i = 0; i < nested_level; i++) { s << " "; } - s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) - << " { \n"; + s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) + << " {\n"; for (const HloInstruction* instruction : MakeInstructionPostOrder()) { for (int i = 0; i < nested_level; i++) { s << " "; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 8e16056b239..af583bed625 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -102,6 +102,32 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { return true; } +// Constructs and returns the new shape with the given minor_to_major order in +// its Layout. +StatusOr MakeShapeWithLayoutInternal( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice minor_to_major) { + if (dimensions.size() != minor_to_major.size()) { + return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", + dimensions.size(), minor_to_major.size()); + } + if (element_type == OPAQUE || element_type == TUPLE) { + return InvalidArgument("Unsupported element type: %s", + PrimitiveType_Name(element_type).c_str()); + } + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); + min2maj->Clear(); + for (int64 value : minor_to_major) { + min2maj->Add(value); + } + if (!shape.has_layout()) { + return InvalidArgument("Shape has no layout."); + } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); + return shape; +} + } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { @@ -152,16 +178,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { /* static */ Shape ShapeUtil::MakeShapeWithLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice minor_to_major) { - CHECK_EQ(dimensions.size(), minor_to_major.size()); - Shape shape = MakeShape(element_type, dimensions); - auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->Clear(); - for (int64 value : minor_to_major) { - min2maj->Add(value); - } - DCHECK(shape.has_layout()); - TF_DCHECK_OK(ValidateShape(shape)); - return shape; + return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) + .ValueOrDie(); } /* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( @@ -499,11 +517,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // Extract the layout minor-to-major and set it. TF_ASSIGN_OR_RETURN(std::vector min2maj, comma_list_to_int64s(layout_string)); - TF_RET_CHECK(dimensions.size() == min2maj.size()); - result = - ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj); + TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( + primitive_type, dimensions, min2maj)); } - TF_DCHECK_OK(ShapeUtil::ValidateShape(result)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); return std::move(result); } diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD new file mode 100644 index 00000000000..c84ca9fc833 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -0,0 +1,84 @@ +# Build file for the Hlo parser. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [":friends"], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + "hlo_token.h", + ], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + ], +) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo_lexer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_parser", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md new file mode 100644 index 00000000000..a334bc2b297 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -0,0 +1,69 @@ +# HloModule string syntax + +TODO: Support subcomputations (for fusion, reduce, while, ...). + +TODO: Support ops that require extra attributes, e.g. dimensions, strides. + +```yacc +hlo_module + : 'HloModule' name computation + ; + +computation + : 'ENTRY' name param_list '->' shape instruction_list + ; + +instruction_list + : '{' instruction_list1 '}' + ; +instruction_list1 + : instruction + | instruction_list1 instruction + ; +instruction + : name '=' shape opcode operands + ; + +operands + : '(' operands1 ')' + ; +operands1 + : /*empty*/ + | operand + | operands1 ',' operand + ; +operand + : shape name + ; + +param_list + : '(' param_list1 ')' + ; +param_list1 + : /*empty*/ + | param + | param_list1 ',' param + ; +param + : name shape + ; + +shape + : shape_val_ + | '(' tuple_elements ')' + ; +tuple_elements + : /*empty*/ + | shape (',' shape)* + ; + +name + : identifier ':' + | '%' identifier + ; + +identifier + : [a-zA-Z_][a-zA-Z0-9_.-]* + ; + +``` diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc new file mode 100644 index 00000000000..3e84ffcbd2c --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -0,0 +1,270 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" + +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/regexp.h" + +namespace xla { +namespace tools { + +using tensorflow::StringPiece; + +namespace { + +constexpr int kEOF = -1; +constexpr int kError = -2; + +// [a-zA-Z0-9_.-] +bool IsIdentifierChar(char c) { + return isalnum(static_cast(c)) || c == '-' || c == '.' || + c == '_'; +} + +} // namespace + +int HloLexer::GetNextChar() { + int current_char = PeekCurrentChar(); + if (current_char != kEOF && current_char != kError) { + current_ptr_++; + } + return current_char; +} + +int HloLexer::PeekCurrentChar() const { + if (current_ptr_ == buf_.end()) { + return kEOF; + } + char current_char = *current_ptr_; + if (current_char == 0) { + // '\0' should not appear in the middle of the string. + return kError; + } + return static_cast(current_char); +} + +bool HloLexer::CanDereference(const char* ptr) const { + return ptr < buf_.end() && ptr >= buf_.begin(); +} + +StringPiece HloLexer::StringPieceFromPointers(const char* begin, + const char* end) const { + CHECK(begin <= end); + CHECK(begin == buf_.end() || CanDereference(begin)); + CHECK(end == buf_.end() || CanDereference(end)); + return StringPiece(begin, end - begin); +} + +tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( + const char* begin, const char* end) const { + CHECK(begin <= end); + CHECK(begin == buf_.end() || CanDereference(begin)); + CHECK(end == buf_.end() || CanDereference(end)); + return tensorflow::RegexpStringPiece(begin, end - begin); +} + +TokKind HloLexer::LexToken() { + while (true) { + token_start_ = current_ptr_; + + int current_char = GetNextChar(); + switch (current_char) { + default: + // [a-zA-Z_] + if (isalpha(static_cast(current_char)) || + current_char == '_') { + return LexIdentifier(); + } + return TokKind::kError; + case kEOF: + // Hit the end of the input buffer. + return TokKind::kEof; + case kError: + // Hit an invalid character in the input buffer. + return TokKind::kError; + case ' ': + case '\t': + case '\n': + case '\r': + // Ignore whitespace. + continue; + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '-': + if (current_char == '-' && PeekCurrentChar() == '>') { + current_ptr_++; + return TokKind::kArrow; + } + return LexDigitOrNegative(); + case '=': + return TokKind::kEqual; + case ',': + return TokKind::kComma; + case '%': + return LexPercent(); + case ':': + return TokKind::kColon; + case '[': + return TokKind::kLsquare; + case ']': + return TokKind::kRsquare; + case '{': + return TokKind::kLbrace; + case '}': + return TokKind::kRbrace; + case '(': + return TokKind::kLparen; + case ')': + return TokKind::kRparen; + } + } +} + +// Lex a shape, name, keyword, or opcode. +// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? +// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: +// keyword ::= HloModule, ENTRY, ... +// opcode ::= add, greater-than, ... +TokKind HloLexer::LexIdentifier() { + { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + // 'consumable' will be advanced iff its prefix matches the pattern. + static LazyRE2 shape_pattern = { + R"(^(\w*\d*)\[([\d,]*)\](?:\s*{([\d,]*)})?)"}; + if (RE2::Consume(&consumable, *shape_pattern)) { + auto status_or_shape = ShapeUtil::ParseShapeString( + StringPieceFromPointers(token_start_, consumable.begin())); + if (status_or_shape.ok()) { + // This is a shape string. + shape_val_ = status_or_shape.ValueOrDie(); + current_ptr_ = consumable.begin(); + return TokKind::kShape; + } + } + } + + while (IsIdentifierChar(PeekCurrentChar())) { + current_ptr_++; + } + + // If followed by ':', it's a name. + if (PeekCurrentChar() == ':') { + str_val_.assign(token_start_, current_ptr_); + current_ptr_++; // skip ':' + return TokKind::kName; + } + + StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_); + + // See if this is a keyword. +#define KEYWORD(STR) \ + do { \ + if (identifier == #STR) { \ + return TokKind::kw_##STR; \ + } \ + } while (false) + + KEYWORD(true); + KEYWORD(false); + KEYWORD(HloModule); + KEYWORD(ENTRY); + +#undef KEYWORD + + // See if this is an opcode. + auto opcode = StringToHloOpcode(identifier.ToString()); + if (opcode.ok()) { + opcode_val_ = opcode.ValueOrDie(); + return TokKind::kOpcode; + } + + current_ptr_ = token_start_ + 1; + return TokKind::kError; +} + +// Lex names after a % character. +// name ::= [a-zA-Z_][a-zA-Z0-9_.-]* +TokKind HloLexer::LexPercent() { + const char* name_start = current_ptr_; + if (isalpha(static_cast(PeekCurrentChar())) || + PeekCurrentChar() == '_') { + current_ptr_++; + while (IsIdentifierChar(PeekCurrentChar())) { + current_ptr_++; + } + str_val_.assign(name_start, current_ptr_); + return TokKind::kName; + } + return TokKind::kError; +} + +// Lex integer and floating-point values. +// int [-]?[0-9]+ +// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) +// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) +TokKind HloLexer::LexDigitOrNegative() { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 float_pattern = { + R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"}; + if (RE2::Consume(&consumable, *float_pattern)) { + current_ptr_ = consumable.begin(); + tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), + &decimal_val_); + return TokKind::kDecimal; + } + + static LazyRE2 int_pattern = {R"([-]?\d+)"}; + if (RE2::Consume(&consumable, *int_pattern)) { + current_ptr_ = consumable.begin(); + tensorflow::strings::safe_strto64( + StringPieceFromPointers(token_start_, current_ptr_), &int64_val_); + return TokKind::kInt; + } + + return TokKind::kError; +} + +StringPiece HloLexer::GetCurrentLine() const { + const char* start = token_start_; + const char* end = current_ptr_; + if (!CanDereference(start) || !CanDereference(end)) { + return "LINE OUT OF RANGE"; + } + while (start > buf_.begin() && *start != '\n') { + start--; + } + while (end < buf_.end() && *end != '\n') { + end++; + } + return StringPieceFromPointers(start, end); +} + +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h new file mode 100644 index 00000000000..20278fd6cde --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace tools { + +// Lexer for the HloModule::ToString() format text. +class HloLexer { + public: + explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { + current_ptr_ = buf_.begin(); + } + + TokKind Lex() { return current_kind_ = LexToken(); } + TokKind GetKind() const { return current_kind_; } + string GetStrVal() const { + CHECK(GetKind() == TokKind::kName); + return str_val_; + } + Shape GetShapeVal() const { + CHECK(GetKind() == TokKind::kShape); + return shape_val_; + } + HloOpcode GetOpcodeVal() const { + CHECK(GetKind() == TokKind::kOpcode); + return opcode_val_; + } + int64 GetInt64Val() const { + CHECK(GetKind() == TokKind::kInt); + return int64_val_; + } + double GetDecimalVal() const { + CHECK(GetKind() == TokKind::kDecimal); + return decimal_val_; + } + + // Returns the line of text that is currently being lexed. + tensorflow::StringPiece GetCurrentLine() const; + + private: + // Returns the current character. If it's neither the end of input buffer nor + // an invalid character, moves the pointer forward. + int GetNextChar(); + + // Returns the current character. + int PeekCurrentChar() const; + + // Creates StringPiece with the given begin and end. Exits if the begin > end, + // or it's out of the range of the current buffer. + tensorflow::StringPiece StringPieceFromPointers(const char* begin, + const char* end) const; + tensorflow::RegexpStringPiece RegexpStringPieceFromPointers( + const char* begin, const char* end) const; + + // Returns true if the given ptr is dereferenceable within the range of the + // current buffer. + bool CanDereference(const char* ptr) const; + + TokKind LexToken(); + + TokKind LexIdentifier(); + TokKind LexPercent(); + TokKind LexShape(); + TokKind LexConstant(); + TokKind LexDigitOrNegative(); + + const tensorflow::StringPiece buf_; + const char* current_ptr_; + + // Information about the current token. + const char* token_start_; + TokKind current_kind_; + string str_val_; + Shape shape_val_; + HloOpcode opcode_val_; + int64 int64_val_; + double decimal_val_; +}; + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc new file mode 100644 index 00000000000..57700493e6c --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -0,0 +1,502 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace tools { + +namespace { + +using tensorflow::StringPiece; +using tensorflow::strings::StrCat; + +// Parser for the HloModule::ToString() format text. +class HloParser { + public: + explicit HloParser(StringPiece str) : lexer_(str) {} + + // Runs the parser. Returns false if an error occurred. + bool Run(); + + // Returns the parsed HloModule. + std::unique_ptr ConsumeHloModule() { return std::move(module_); } + + // Returns the error information. + string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + + private: + // ParseXXX returns false if an error occurred. + bool ParseHloModule(); + bool ParseComputation(); + bool ParseInstructionList(HloComputation::Builder* builder); + bool ParseInstruction(HloComputation::Builder* builder); + bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); + bool ParseOperands(std::vector* operands, + const int expected_size); + bool ParseParamList(); + bool ParseName(string* result); + bool ParseShape(Shape* result); + bool ParseOpcode(HloOpcode* result); + bool ParseInt64(int64* result); + bool ParseDecimal(double* result); + bool ParseBool(bool* result); + bool ParseToken(TokKind kind, const string& msg); + + // Logs the current parsing line and the given message. Always returns false. + bool TokenError(StringPiece msg); + + // If the current token is 'kind', eats it (i.e. lexes the next token) and + // returns true. + bool EatIfPresent(TokKind kind); + + // Adds the instruction to the pool. Returns false and emits an error if the + // instruction already exists. + bool AddInstruction(const string& name, HloInstruction* instruction); + + // The map from the instruction name to the instruction. This does not own the + // instructions. + std::unordered_map instruction_pool_; + + HloLexer lexer_; + std::unique_ptr module_; + std::vector error_; +}; + +bool HloParser::TokenError(StringPiece msg) { + error_.push_back( + StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; ", msg)); + return false; +} + +bool HloParser::Run() { + lexer_.Lex(); + return ParseHloModule(); +} + +// ::= 'HloModule' name computation +bool HloParser::ParseHloModule() { + if (lexer_.GetKind() != TokKind::kw_HloModule) { + return TokenError("expects HloModule"); + } + // Eat 'HloModule' + lexer_.Lex(); + + string name; + if (!ParseName(&name)) { + return false; + } + + module_ = MakeUnique(name); + + return ParseComputation(); +} + +// computation ::= 'ENTRY' name param_list '->' shape instruction_list +bool HloParser::ParseComputation() { + string name; + if (!ParseToken(TokKind::kw_ENTRY, "expects 'ENTRY'") || !ParseName(&name)) { + return false; + } + auto builder = MakeUnique(name); + + Shape shape; + if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") || + !ParseShape(&shape) || !ParseInstructionList(builder.get())) { + return false; + } + module_->AddEntryComputation(builder->Build()); + return true; +} + +// instruction_list ::= '{' instruction_list1 '}' +// instruction_list1 ::= (instruction)+ +bool HloParser::ParseInstructionList(HloComputation::Builder* builder) { + if (!ParseToken(TokKind::kLbrace, + "expects '{' at the beginning of instruction list.")) { + return false; + } + do { + if (!ParseInstruction(builder)) { + return false; + } + } while (lexer_.GetKind() != TokKind::kRbrace); + return ParseToken(TokKind::kRbrace, + "expects '}' at the end of instruction list."); +} + +// instruction ::= name '=' shape opcode operands +bool HloParser::ParseInstruction(HloComputation::Builder* builder) { + string name; + Shape shape; + HloOpcode opcode; + std::vector operands; + if (!ParseName(&name) || + !ParseToken(TokKind::kEqual, "expects '=' in instruction") || + !ParseShape(&shape) || !ParseOpcode(&opcode)) { + return false; + } + switch (opcode) { + case HloOpcode::kParameter: { + int64 parameter_number; + return ParseToken(TokKind::kLparen, + "expects '(' before parameter number") && + ParseInt64(¶meter_number) && + ParseToken(TokKind::kRparen, + "expects ')' after parameter number") && + AddInstruction( + name, builder->AddInstruction(HloInstruction::CreateParameter( + parameter_number, shape, name))); + } + case HloOpcode::kConstant: { + std::unique_ptr literal; + return ParseToken(TokKind::kLparen, + "expects '(' before parameter number") && + ParseLiteral(&literal, shape) && + ParseToken(TokKind::kRparen, + "expects ')' after parameter number") && + AddInstruction( + name, builder->AddInstruction( + HloInstruction::CreateConstant(std::move(literal)))); + } + // Unary ops. + case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kBitcast: + case HloOpcode::kCeil: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kExp: + case HloOpcode::kIsFinite: + case HloOpcode::kFloor: + case HloOpcode::kLog: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSort: + case HloOpcode::kTanh: { + return ParseOperands(&operands, /*expected_size=*/1) && + AddInstruction(name, + builder->AddInstruction(HloInstruction::CreateUnary( + shape, opcode, operands[0]))); + } + // Binary ops. + case HloOpcode::kAdd: + case HloOpcode::kDivide: + case HloOpcode::kMultiply: + case HloOpcode::kSubtract: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kNe: + case HloOpcode::kDot: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: { + return ParseOperands(&operands, /*expected_size=*/2) && + AddInstruction( + name, builder->AddInstruction(HloInstruction::CreateBinary( + shape, opcode, operands[0], operands[1]))); + } + // Ternary ops. + case HloOpcode::kClamp: + case HloOpcode::kSelect: { + return ParseOperands(&operands, /*expected_size=*/3) && + AddInstruction( + name, + builder->AddInstruction(HloInstruction::CreateTernary( + shape, opcode, operands[0], operands[1], operands[2]))); + } + // Other supported ops. + case HloOpcode::kConvert: { + return ParseOperands(&operands, /*expected_size=*/1) && + AddInstruction( + name, builder->AddInstruction( + HloInstruction::CreateConvert(shape, operands[0]))); + } + case HloOpcode::kCrossReplicaSum: { + return ParseOperands(&operands, /*expected_size=*/1) && + AddInstruction(name, builder->AddInstruction( + HloInstruction::CreateCrossReplicaSum( + shape, operands[0]))); + } + case HloOpcode::kReshape: { + return ParseOperands(&operands, /*expected_size=*/1) && + AddInstruction( + name, builder->AddInstruction( + HloInstruction::CreateReshape(shape, operands[0]))); + } + case HloOpcode::kBroadcast: + case HloOpcode::kCall: + case HloOpcode::kCustomCall: + case HloOpcode::kConcatenate: + case HloOpcode::kReducePrecision: + case HloOpcode::kConvolution: + case HloOpcode::kGetTupleElement: + case HloOpcode::kMap: + case HloOpcode::kPad: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kReverse: + case HloOpcode::kRng: + case HloOpcode::kSlice: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + case HloOpcode::kFusion: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kRecv: + case HloOpcode::kSend: + case HloOpcode::kUpdate: + case HloOpcode::kIndex: + case HloOpcode::kTrace: + return TokenError(StrCat("parsing not yet implemented for op: ", + HloOpcodeString(opcode))); + } +} + +bool HloParser::ParseLiteral(std::unique_ptr* literal, + const Shape& shape) { + switch (shape.element_type()) { + case PRED: + bool b; + if (!ParseBool(&b)) { + return false; + } + *literal = Literal::CreateR0(b); + return true; + case S32: + int64 i; + if (!ParseInt64(&i)) { + return false; + } + *literal = Literal::CreateR0(i); + return true; + case F32: + double d; + if (!ParseDecimal(&d)) { + return false; + } + *literal = Literal::CreateR0(d); + return true; + default: + return TokenError(StrCat("unsupported constant in shape: ", + ShapeUtil::HumanString(shape))); + } +} + +// operands ::= '(' operands1 ')' +// operands1 +// ::= /*empty*/ +// ::= operand (, operand)* +// operand ::= shape name +bool HloParser::ParseOperands(std::vector* operands, + const int expected_size) { + if (!ParseToken(TokKind::kLparen, + "expects '(' at the beginning of operands")) { + return false; + } + if (lexer_.GetKind() == TokKind::kRparen) { + // empty + } else { + do { + Shape shape; + string name; + if (!ParseShape(&shape) || !ParseName(&name)) { + return false; + } + HloInstruction* instruction = + tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + if (!instruction) { + return TokenError(StrCat("instruction does not exist: ", name)); + } + operands->push_back(instruction); + } while (EatIfPresent(TokKind::kComma)); + } + if (expected_size != operands->size()) { + return TokenError(StrCat("expects ", expected_size, " operands, but has ", + operands->size(), " operands")); + } + return ParseToken(TokKind::kRparen, "expects ')' at the end of operands"); +} + +// param_list ::= '(' param_list1 ')' +// param_list1 +// ::= /*empty*/ +// ::= param (',' param)* +// param ::= name shape +bool HloParser::ParseParamList() { + if (!ParseToken(TokKind::kLparen, + "expects '(' at the beginning of param list")) { + return false; + } + + if (lexer_.GetKind() == TokKind::kRparen) { + // empty + } else { + do { + Shape shape; + if (!ParseToken(TokKind::kName, "expects name in parameter") || + !ParseShape(&shape)) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); +} + +// shape ::= shape_val_ +// shape ::= '(' tuple_elements ')' +// tuple_elements +// ::= /*empty*/ +// ::= shape (',' shape)* +bool HloParser::ParseShape(Shape* result) { + if (EatIfPresent(TokKind::kLparen)) { // Tuple + std::vector shapes; + if (lexer_.GetKind() == TokKind::kRparen) { + /*empty*/ + } else { + // shape (',' shape)* + do { + shapes.emplace_back(); + if (!ParseShape(&shapes.back())) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + *result = ShapeUtil::MakeTupleShape(shapes); + return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple."); + } + + if (lexer_.GetKind() != TokKind::kShape) { + return TokenError("expects shape"); + } + *result = lexer_.GetShapeVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseName(string* result) { + VLOG(1) << "ParseName"; + if (lexer_.GetKind() != TokKind::kName) { + return TokenError("expects name"); + } + *result = lexer_.GetStrVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseOpcode(HloOpcode* result) { + VLOG(1) << "ParseOpcode"; + if (lexer_.GetKind() != TokKind::kOpcode) { + return TokenError("expects opcode"); + } + *result = lexer_.GetOpcodeVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseInt64(int64* result) { + VLOG(1) << "ParseInt64"; + if (lexer_.GetKind() != TokKind::kInt) { + return TokenError("expects integer"); + } + *result = lexer_.GetInt64Val(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseDecimal(double* result) { + switch (lexer_.GetKind()) { + case TokKind::kDecimal: + *result = lexer_.GetDecimalVal(); + break; + case TokKind::kInt: + *result = static_cast(lexer_.GetInt64Val()); + break; + default: + return TokenError("expects decimal or integer"); + } + lexer_.Lex(); + return true; +} + +bool HloParser::ParseBool(bool* result) { + if (lexer_.GetKind() != TokKind::kw_true && + lexer_.GetKind() != TokKind::kw_false) { + return TokenError("expects true or false"); + } + *result = lexer_.GetKind() == TokKind::kw_true; + lexer_.Lex(); + return true; +} + +bool HloParser::ParseToken(TokKind kind, const string& msg) { + if (lexer_.GetKind() != kind) { + return TokenError(msg); + } + lexer_.Lex(); + return true; +} + +bool HloParser::EatIfPresent(TokKind kind) { + if (lexer_.GetKind() != kind) { + return false; + } + lexer_.Lex(); + return true; +} + +bool HloParser::AddInstruction(const string& name, + HloInstruction* instruction) { + auto result = instruction_pool_.insert({name, instruction}); + if (!result.second) { + return TokenError(StrCat("instruction already exists: ", name)); + } + return true; +} + +} // namespace + +StatusOr> Parse(StringPiece str) { + HloParser parser(str); + if (!parser.Run()) { + return InvalidArgument("Syntax error: %s", parser.GetError().c_str()); + } + return parser.ConsumeHloModule(); +} + +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/tools/parser/hlo_parser.h new file mode 100644 index 00000000000..9aaf18ef20d --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.h @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace tools { + +// The api of the hlo parser. Given a string in the HloModule::ToString() +// format, returns the parsed HloModule. +StatusOr> Parse(tensorflow::StringPiece str); + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc new file mode 100644 index 00000000000..4ecece3eac1 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -0,0 +1,240 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +#include +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace tools { +namespace { + +struct TestData { + string test_name; + string module_string; +}; + +string TestDataToString(const ::testing::TestParamInfo& data) { + return data.param.test_name; +} + +std::vector CreateTestCases() { + // clang-format off + return std::vector({ +// ax + y +{ +"AxpyParam", +R"(HloModule axpy_module: + +ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[2,4]{1,0} parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} + +)" +}, +// pred constant +{ +"ConstantPred", +R"(HloModule constant_pred_module: + +ENTRY %constant_pred () -> pred[] { + %constant = pred[] constant(true) +} + +)" +}, +// s32 constant +{ +"ConstantS32", +R"(HloModule constant_s32_module: + +ENTRY %constant_s32 () -> s32[] { + %constant = s32[] constant(-42) +} + +)" +}, +// f32 constant, but the value is not a decimal +{ +"ConstantF32", R"(HloModule ConstantF32_module: + +ENTRY %ConstantF32.v4 () -> f32[] { + %constant = f32[] constant(42) +} + +)" +}, +// constant + constant +{ +"AddConstants", +R"(HloModule add_constants_module: + +ENTRY %add_constants () -> f32[] { + %constant = f32[] constant(3.14) + %add = f32[] add(f32[] %constant, f32[] %constant) +} + +)" +}, +// v1 > v2 ? v1 : v2 +{ +"SelectR1F32", +R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module: + +ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { + %v1 = f32[4]{0} parameter(0) + %v2 = f32[4]{0} parameter(1) + %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2) + %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2) +} + +)" +} + }); + // clang-format on +} + +class HloParserTest : public ::testing::Test, + public ::testing::WithParamInterface { + protected: + void ExpectSuccess() { + const string& original = GetParam().module_string; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + EXPECT_EQ(original, result.ValueOrDie()->ToString()); + } +}; + +TEST_P(HloParserTest, Run) { ExpectSuccess(); } + +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); + +TEST_F(HloParserTest, Empty) { + const string original = ""; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, Garbage) { + const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongOpcode) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { + %x = f32[]{} parameter(0) + %y = f32[]{} parameter(1) + %le = pred[]{} le(f32[]{} %x, f32[]{} %y) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongShape) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: g32[]) -> g32[] { + %x = g32[]{} parameter(0) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongOperandsSize) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: f32[]) -> pred[] { + %x = f32[]{} parameter(0) + %eq = pred[]{} equal-to(f32[]{} %x) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, OperandNotFound) { + const string original = R"(HloModule operand_not_found: +ENTRY %blabla (x: f32[]) -> pred[] { + %x = f32[]{} parameter(0) + %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) +} +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, MoreConstants) { + const string original = R"(HloModule SelectScalarS32True_module: + +ENTRY %SelectScalarS32True.v4 () -> s32[] { + %constant.2 = pred[] constant(true) + %constant.1 = s32[] constant(-42) + %constant = s32[] constant(42) + %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) +} + +)"; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + // Constant instructions have no name. The string will be parsed successfully + // but the constant names will not be exactly the same. +} + +TEST_F(HloParserTest, ConstantWithExp) { + const string original = R"(HloModule ConstantWithExp_module: + +ENTRY %ConstantWithExp.v4 () -> f32[] { + %constant.1 = f32[] constant(3e+2) +} + +)"; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + // The string will be parsed successfully but the output strings are not + // exactly the same, because "3e2" is parsed into value 300 and will be + // printed as "300". +} + +TEST_F(HloParserTest, Tuple) { + const string original = R"(HloModule EmptyTupleCreate_module: + +ENTRY %EmptyTupleCreate.v1 () -> () { + %tuple = () tuple() +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +} // namespace +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h new file mode 100644 index 00000000000..1f75e17c7f0 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ + +namespace xla { +namespace tools { + +// Defines different kinds of tokens in a hlo module string. +enum class TokKind { + // Markers + kEof, + kError, + + // Tokens with no info. + kEqual, // = + kComma, // , + kColon, // : + kLsquare, + kRsquare, // [ ] + kLbrace, + kRbrace, // { } + kLparen, + kRparen, // ( ) + + kArrow, // -> + + // Keywords + kw_HloModule, + kw_ENTRY, + kw_true, + kw_false, + + // Typed tokens. + kName, // %foo + kShape, // f32[2,3]{1,0} + kOpcode, // add + kInt, // 42 + kDecimal, // 4.2 +}; + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ From 2cd178ef5a4e5cac27b55729f0203c4864540063 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Thu, 19 Oct 2017 15:22:08 -0700 Subject: [PATCH 04/41] [XLA] Teach transpose folding how to transpose the LHS of convolutions This is now possible now that we have added the required fields to ConvolutionDimensionNumbers. PiperOrigin-RevId: 172807540 --- .../compiler/xla/service/transpose_folding.cc | 105 ++++++++++++------ .../xla/service/transpose_folding_test.cc | 28 +++-- 2 files changed, 89 insertions(+), 44 deletions(-) diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 816c8a7485b..8c2640adf52 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -58,14 +58,32 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution( return {}; } - // We only support folding the RHS. - const int64 kRhsOperandIndex = 1; - auto& operand = *convolution.operand(kRhsOperandIndex); - if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) { - return transposable_conv_operands(convolution, {kRhsOperandIndex}); + const ConvolutionDimensionNumbers& dnums = + convolution.convolution_dimension_numbers(); + + TransposeFolding::OperandIndices operand_set; + for (int64 i = 0; i < convolution.operand_count(); ++i) { + auto& operand = *convolution.operand(i); + if (operand.opcode() == HloOpcode::kTranspose && + operand.user_count() == 1) { + const auto& transpose_dimensions = operand.dimensions(); + // We can transpose the LHS so long as it doesn't move around spatial + // dimensions because ConvolutionDimensionNumbers doesn't have different + // fields for input and output spatial dimensions. + if (i == 0 && + std::any_of(dnums.spatial_dimensions().begin(), + dnums.spatial_dimensions().end(), + [&](const int64 spatial_dimension) { + return transpose_dimensions[spatial_dimension] != + spatial_dimension; + })) { + continue; + } + operand_set.push_back(i); + } } - return {}; + return transposable_conv_operands(convolution, operand_set); } using InstructionOperandsPair = @@ -98,40 +116,61 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair) { // Returns whether the module is changed. bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto& convolution = *pair.first; - - // We only support fusing the RHS transpose into convolution. - // - // ConvolutionDimensionNumbers doesn't make enough of a distinction between - // the output and the activations. - // - // TODO(b/37125184): Support transposing the LHS too. - if (pair.second.size() != 1 || pair.second.front() != 1) { - return false; - } + auto& operand_indices = pair.second; const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); - HloInstruction& transpose = *convolution.mutable_operand(1); - CHECK_EQ(transpose.opcode(), HloOpcode::kTranspose); - const auto& transpose_dimensions = transpose.dimensions(); - HloInstruction& transpose_operand = *transpose.mutable_operand(0); - - // Everything remains the same except for the kernel dimension numbers. We - // need to apply the transpose permutation to the original shape to figure out - // what the new logical dimensions are. ConvolutionDimensionNumbers new_dnums = dnums; - new_dnums.set_kernel_input_feature_dimension( - transpose_dimensions[dnums.kernel_input_feature_dimension()]); - new_dnums.set_kernel_output_feature_dimension( - transpose_dimensions[dnums.kernel_output_feature_dimension()]); - for (auto& kernel_spatial_dimension : - *new_dnums.mutable_kernel_spatial_dimensions()) { - kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension]; + + HloInstruction* new_lhs; + const int64 kLhsIdx = 0; + if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) != + operand_indices.end()) { + HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx); + const auto& transpose_dimensions = transpose.dimensions(); + HloInstruction& transpose_operand = *transpose.mutable_operand(0); + + // Everything remains the same except for the input/output dimension + // numbers. We need to apply the transpose permutation to the original shape + // to figure out what the new logical dimensions are. + new_dnums.set_input_batch_dimension( + transpose_dimensions[dnums.input_batch_dimension()]); + new_dnums.set_input_feature_dimension( + transpose_dimensions[dnums.input_feature_dimension()]); + for (const auto& spatial_dimension : dnums.spatial_dimensions()) { + CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]); + } + new_lhs = &transpose_operand; + } else { + new_lhs = convolution.mutable_operand(kLhsIdx); + } + + HloInstruction* new_rhs; + const int64 kRhsIdx = 1; + if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) != + operand_indices.end()) { + HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx); + const auto& transpose_dimensions = transpose.dimensions(); + HloInstruction& transpose_operand = *transpose.mutable_operand(0); + + // Everything remains the same except for the kernel dimension numbers. We + // need to apply the transpose permutation to the original shape to figure + // out what the new logical dimensions are. + new_dnums.set_kernel_input_feature_dimension( + transpose_dimensions[dnums.kernel_input_feature_dimension()]); + new_dnums.set_kernel_output_feature_dimension( + transpose_dimensions[dnums.kernel_output_feature_dimension()]); + for (auto& kernel_spatial_dimension : + *new_dnums.mutable_kernel_spatial_dimensions()) { + kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension]; + } + new_rhs = &transpose_operand; + } else { + new_rhs = convolution.mutable_operand(kRhsIdx); } auto new_conv = HloInstruction::CreateConvolve( - convolution.shape(), convolution.mutable_operand(0), &transpose_operand, - convolution.window(), new_dnums); + convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index a6161b46460..00462f9be1e 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -313,8 +313,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1)); } -// Test that a transpose of the activations does not get folded into -// convolution. +// Test that a transpose of the activations gets folded into convolution. TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { auto builder = HloComputation::Builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -348,18 +347,25 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { module.AddEntryComputation(builder.Build(conv)); FoldTranspose(&module); - // Instructions after folding: transpose_x, y, and the convolution. + // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( entry_computation->instructions().begin(), entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(transpose_x)) - << "transpose_x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(conv)) - << "transpose_x is not in entry_computation."; - CHECK_EQ(0, instruction_set.size()) - << "entry_computation should contain exactly 4 instructions."; + EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + EXPECT_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.input_feature_dimension(), + new_conv->convolution_dimension_numbers().input_batch_dimension()); + EXPECT_EQ( + dnums.input_batch_dimension(), + new_conv->convolution_dimension_numbers().input_feature_dimension()); + EXPECT_EQ(dnums.spatial_dimensions(0), + new_conv->convolution_dimension_numbers().spatial_dimensions(0)); + EXPECT_EQ(dnums.spatial_dimensions(1), + new_conv->convolution_dimension_numbers().spatial_dimensions(1)); } } // namespace From f080052284a4a39113051fb1178d91365e9872a8 Mon Sep 17 00:00:00 2001 From: Igor Saprykin Date: Thu, 19 Oct 2017 15:27:52 -0700 Subject: [PATCH 05/41] Move text_classification_character_rnn from .contrib utils to .core utils. Also removes sklearn comparison. PiperOrigin-RevId: 172808535 --- .../text_classification_character_rnn.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/tensorflow/examples/learn/text_classification_character_rnn.py b/tensorflow/examples/learn/text_classification_character_rnn.py index 1fc9388a1a0..86adc056add 100644 --- a/tensorflow/examples/learn/text_classification_character_rnn.py +++ b/tensorflow/examples/learn/text_classification_character_rnn.py @@ -30,7 +30,6 @@ import sys import numpy as np import pandas -from sklearn import metrics import tensorflow as tf FLAGS = None @@ -46,8 +45,8 @@ def char_rnn_model(features, labels, mode): byte_vectors = tf.one_hot(features[CHARS_FEATURE], 256, 1., 0.) byte_list = tf.unstack(byte_vectors, axis=1) - cell = tf.contrib.rnn.GRUCell(HIDDEN_SIZE) - _, encoding = tf.contrib.rnn.static_rnn(cell, byte_list, dtype=tf.float32) + cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE) + _, encoding = tf.nn.static_rnn(cell, byte_list, dtype=tf.float32) logits = tf.layers.dense(encoding, MAX_LABEL, activation=None) @@ -98,28 +97,20 @@ def main(unused_argv): train_input_fn = tf.estimator.inputs.numpy_input_fn( x={CHARS_FEATURE: x_train}, y=y_train, - batch_size=len(x_train), + batch_size=128, num_epochs=None, shuffle=True) classifier.train(input_fn=train_input_fn, steps=100) - # Predict. + # Eval. test_input_fn = tf.estimator.inputs.numpy_input_fn( x={CHARS_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False) - predictions = classifier.predict(input_fn=test_input_fn) - y_predicted = np.array(list(p['class'] for p in predictions)) - y_predicted = y_predicted.reshape(np.array(y_test).shape) - # Score with sklearn. - score = metrics.accuracy_score(y_test, y_predicted) - print('Accuracy (sklearn): {0:f}'.format(score)) - - # Score with tensorflow. scores = classifier.evaluate(input_fn=test_input_fn) - print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy'])) + print('Accuracy: {0:f}'.format(scores['accuracy'])) if __name__ == '__main__': From bc93dcbd9f7b445c5f6f0d1c8f597324d412a76a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Oct 2017 16:00:31 -0700 Subject: [PATCH 06/41] Fix precision/recall test. Precision and Recall have as the numerator TP: true positives. The labels generated in the test were only negative, and hence the test passed before because all updates were 0. PiperOrigin-RevId: 172812994 --- .../metrics/python/ops/metric_ops_test.py | 58 +++++++++---------- 1 file changed, 26 insertions(+), 32 deletions(-) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index cc0ad155fa0..f288fceef6c 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1101,7 +1101,7 @@ class StreamingPrecisionTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) precision, update_op = metrics.streaming_precision(predictions, labels) with self.test_session() as sess: @@ -1265,7 +1265,7 @@ class StreamingRecallTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) recall, update_op = metrics.streaming_recall(predictions, labels) with self.test_session() as sess: @@ -1388,7 +1388,7 @@ class StreamingFPRTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) fpr, update_op = metrics.streaming_false_positive_rate( predictions, labels) @@ -1516,7 +1516,7 @@ class StreamingFNRTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) fnr, update_op = metrics.streaming_false_negative_rate( predictions, labels) @@ -1737,7 +1737,7 @@ class StreamingAUCTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) auc, update_op = metrics.streaming_auc(predictions, labels) with self.test_session() as sess: @@ -2009,7 +2009,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.7) @@ -2271,7 +2271,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, labels, @@ -2282,12 +2282,14 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): with self.test_session() as sess: sess.run(variables.local_variables_initializer()) - # Run several updates, then verify idempotency. - sess.run([prec_op, rec_op]) + # Run several updates. + for _ in range(10): + sess.run([prec_op, rec_op]) + + # Then verify idempotency. initial_prec = prec.eval() initial_rec = rec.eval() for _ in range(10): - sess.run([prec_op, rec_op]) self.assertAllClose(initial_prec, prec.eval()) self.assertAllClose(initial_rec, rec.eval()) @@ -2361,14 +2363,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): rec, rec_op = metrics.streaming_recall_at_thresholds( predictions, labels, thresholds, weights=weights) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - prec_low = array_ops.reshape(prec_low, shape=()) - prec_high = array_ops.reshape(prec_high, shape=()) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) - rec_low = array_ops.reshape(rec_low, shape=()) - rec_high = array_ops.reshape(rec_high, shape=()) + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2391,14 +2389,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): rec, rec_op = metrics.streaming_recall_at_thresholds( predictions, labels, thresholds, weights=weights) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - prec_low = array_ops.reshape(prec_low, shape=()) - prec_high = array_ops.reshape(prec_high, shape=()) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) - rec_low = array_ops.reshape(rec_low, shape=()) - rec_high = array_ops.reshape(rec_high, shape=()) + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2420,10 +2414,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, thresholds) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2562,7 +2556,7 @@ class StreamingFPRThresholdsTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( predictions, labels, thresholds) @@ -2794,7 +2788,7 @@ class StreamingFNRThresholdsTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( predictions, labels, thresholds) From 7a253f3da99c3692d464a8dd95d8280d4cd8973a Mon Sep 17 00:00:00 2001 From: Igor Saprykin Date: Thu, 19 Oct 2017 16:16:29 -0700 Subject: [PATCH 07/41] Fix random_forest_mnist.py and eliminate a contrib.learn reference to skcompat. PiperOrigin-RevId: 172815173 --- .../examples/learn/random_forest_mnist.py | 63 ++++++++++--------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/tensorflow/examples/learn/random_forest_mnist.py b/tensorflow/examples/learn/random_forest_mnist.py index 3c09990ea1e..72c935cdae2 100644 --- a/tensorflow/examples/learn/random_forest_mnist.py +++ b/tensorflow/examples/learn/random_forest_mnist.py @@ -1,4 +1,4 @@ - # Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,18 +21,14 @@ import argparse import sys import tempfile -# pylint: disable=g-backslash-continuation -from tensorflow.contrib.learn.python.learn\ - import metric_spec -from tensorflow.contrib.learn.python.learn.estimators\ - import estimator -from tensorflow.contrib.tensor_forest.client\ - import eval_metrics -from tensorflow.contrib.tensor_forest.client\ - import random_forest -from tensorflow.contrib.tensor_forest.python\ - import tensor_forest +import numpy + +from tensorflow.contrib.learn.python.learn import metric_spec +from tensorflow.contrib.tensor_forest.client import eval_metrics +from tensorflow.contrib.tensor_forest.client import random_forest +from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.platform import app FLAGS = None @@ -41,16 +37,15 @@ FLAGS = None def build_estimator(model_dir): """Build an estimator.""" params = tensor_forest.ForestHParams( - num_classes=10, num_features=784, - num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes) + num_classes=10, + num_features=784, + num_trees=FLAGS.num_trees, + max_nodes=FLAGS.max_nodes) graph_builder_class = tensor_forest.RandomForestGraphs if FLAGS.use_training_loss: graph_builder_class = tensor_forest.TrainingLossForest - # Use the SKCompat wrapper, which gives us a convenient way to split - # in-memory data like MNIST into batches. - return estimator.SKCompat(random_forest.TensorForestEstimator( - params, graph_builder_class=graph_builder_class, - model_dir=model_dir)) + return random_forest.TensorForestEstimator( + params, graph_builder_class=graph_builder_class, model_dir=model_dir) def train_and_eval(): @@ -62,18 +57,30 @@ def train_and_eval(): mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False) - est.fit(x=mnist.train.images, y=mnist.train.labels, - batch_size=FLAGS.batch_size) + train_input_fn = numpy_io.numpy_input_fn( + x={'images': mnist.train.images}, + y=mnist.train.labels.astype(numpy.int32), + batch_size=FLAGS.batch_size, + num_epochs=None, + shuffle=True) + est.fit(input_fn=train_input_fn, steps=None) metric_name = 'accuracy' - metric = {metric_name: - metric_spec.MetricSpec( - eval_metrics.get_metric(metric_name), - prediction_key=eval_metrics.get_prediction_key(metric_name))} + metric = { + metric_name: + metric_spec.MetricSpec( + eval_metrics.get_metric(metric_name), + prediction_key=eval_metrics.get_prediction_key(metric_name)) + } - results = est.score(x=mnist.test.images, y=mnist.test.labels, - batch_size=FLAGS.batch_size, - metrics=metric) + test_input_fn = numpy_io.numpy_input_fn( + x={'images': mnist.test.images}, + y=mnist.test.labels.astype(numpy.int32), + num_epochs=1, + batch_size=FLAGS.batch_size, + shuffle=False) + + results = est.evaluate(input_fn=test_input_fn, metrics=metric) for key in sorted(results): print('%s: %s' % (key, results[key])) From 60a03dfc7dbde7acf58ffaeef897eb3ebb98603f Mon Sep 17 00:00:00 2001 From: Michael Case Date: Thu, 19 Oct 2017 16:18:46 -0700 Subject: [PATCH 08/41] Move s3 file system support from contrib/ to core/platform/. PiperOrigin-RevId: 172815422 --- tensorflow/BUILD | 2 +- tensorflow/contrib/makefile/Makefile | 1 + tensorflow/core/platform/default/build_config.bzl | 2 +- tensorflow/{contrib => core/platform}/s3/BUILD | 0 tensorflow/{contrib => core/platform}/s3/s3_crypto.cc | 2 +- tensorflow/{contrib => core/platform}/s3/s3_crypto.h | 0 tensorflow/{contrib => core/platform}/s3/s3_file_system.cc | 4 ++-- tensorflow/{contrib => core/platform}/s3/s3_file_system.h | 0 .../{contrib => core/platform}/s3/s3_file_system_test.cc | 2 +- 9 files changed, 7 insertions(+), 6 deletions(-) rename tensorflow/{contrib => core/platform}/s3/BUILD (100%) rename tensorflow/{contrib => core/platform}/s3/s3_crypto.cc (98%) rename tensorflow/{contrib => core/platform}/s3/s3_crypto.h (100%) rename tensorflow/{contrib => core/platform}/s3/s3_file_system.cc (99%) rename tensorflow/{contrib => core/platform}/s3/s3_file_system.h (100%) rename tensorflow/{contrib => core/platform}/s3/s3_file_system_test.cc (99%) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index d5c56cdc184..d7d6d5fc77d 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -414,7 +414,6 @@ filegroup( "//tensorflow/contrib/remote_fused_graph/pylib:all_files", "//tensorflow/contrib/resampler:all_files", "//tensorflow/contrib/rnn:all_files", - "//tensorflow/contrib/s3:all_files", "//tensorflow/contrib/saved_model:all_files", "//tensorflow/contrib/saved_model/cc/saved_model:all_files", "//tensorflow/contrib/seq2seq:all_files", @@ -468,6 +467,7 @@ filegroup( "//tensorflow/core/platform/cloud:all_files", "//tensorflow/core/platform/default/build_config:all_files", "//tensorflow/core/platform/hadoop:all_files", + "//tensorflow/core/platform/s3:all_files", "//tensorflow/core/profiler:all_files", "//tensorflow/core/profiler/internal:all_files", "//tensorflow/core/profiler/internal/advisor:all_files", diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index be7c790ee9e..3dcff3d4a3d 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -502,6 +502,7 @@ $(wildcard tensorflow/core/platform/google/*) \ $(wildcard tensorflow/core/platform/google/*/*) \ $(wildcard tensorflow/core/platform/jpeg.*) \ $(wildcard tensorflow/core/platform/png.*) \ +$(wildcard tensorflow/core/platform/s3/*) \ $(wildcard tensorflow/core/platform/stream_executor.*) \ $(wildcard tensorflow/core/platform/windows/*) \ $(wildcard tensorflow/core/user_ops/*.cu.cc) \ diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 2c14ea917c0..e4518a8e2fd 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -467,7 +467,7 @@ def tf_additional_core_deps(): "//conditions:default": [], }) + select({ "//tensorflow:with_s3_support": [ - "//tensorflow/contrib/s3:s3_file_system", + "//tensorflow/core/platform/s3:s3_file_system", ], "//conditions:default": [], }) diff --git a/tensorflow/contrib/s3/BUILD b/tensorflow/core/platform/s3/BUILD similarity index 100% rename from tensorflow/contrib/s3/BUILD rename to tensorflow/core/platform/s3/BUILD diff --git a/tensorflow/contrib/s3/s3_crypto.cc b/tensorflow/core/platform/s3/s3_crypto.cc similarity index 98% rename from tensorflow/contrib/s3/s3_crypto.cc rename to tensorflow/core/platform/s3/s3_crypto.cc index 1450384dc0f..14bbed19a50 100644 --- a/tensorflow/contrib/s3/s3_crypto.cc +++ b/tensorflow/core/platform/s3/s3_crypto.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/s3/s3_crypto.h" +#include "tensorflow/core/platform/s3/s3_crypto.h" #include #include diff --git a/tensorflow/contrib/s3/s3_crypto.h b/tensorflow/core/platform/s3/s3_crypto.h similarity index 100% rename from tensorflow/contrib/s3/s3_crypto.h rename to tensorflow/core/platform/s3/s3_crypto.h diff --git a/tensorflow/contrib/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc similarity index 99% rename from tensorflow/contrib/s3/s3_file_system.cc rename to tensorflow/core/platform/s3/s3_file_system.cc index daced831453..51c85592bf4 100644 --- a/tensorflow/contrib/s3/s3_file_system.cc +++ b/tensorflow/core/platform/s3/s3_file_system.cc @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/s3/s3_file_system.h" -#include "tensorflow/contrib/s3/s3_crypto.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/s3/s3_file_system.h" +#include "tensorflow/core/platform/s3/s3_crypto.h" #include #include diff --git a/tensorflow/contrib/s3/s3_file_system.h b/tensorflow/core/platform/s3/s3_file_system.h similarity index 100% rename from tensorflow/contrib/s3/s3_file_system.h rename to tensorflow/core/platform/s3/s3_file_system.h diff --git a/tensorflow/contrib/s3/s3_file_system_test.cc b/tensorflow/core/platform/s3/s3_file_system_test.cc similarity index 99% rename from tensorflow/contrib/s3/s3_file_system_test.cc rename to tensorflow/core/platform/s3/s3_file_system_test.cc index 949281fad4a..0b42f5fcec0 100644 --- a/tensorflow/contrib/s3/s3_file_system_test.cc +++ b/tensorflow/core/platform/s3/s3_file_system_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/s3/s3_file_system.h" +#include "tensorflow/core/platform/s3/s3_file_system.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" From d88cccebc7f61078d775d26f4714a06bc4002fcf Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Thu, 19 Oct 2017 16:20:06 -0700 Subject: [PATCH 09/41] Rename SNAPPY to TF_USE_SNAPPY This way there's less risk of it conflicting with downstream BUILD rules. PiperOrigin-RevId: 172815580 --- tensorflow/contrib/cmake/external/snappy.cmake | 2 +- tensorflow/core/BUILD | 2 +- tensorflow/core/platform/posix/port.cc | 8 ++++---- tensorflow/core/platform/windows/port.cc | 8 ++++---- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tensorflow/contrib/cmake/external/snappy.cmake b/tensorflow/contrib/cmake/external/snappy.cmake index a35d8654fb6..2d2451521c0 100644 --- a/tensorflow/contrib/cmake/external/snappy.cmake +++ b/tensorflow/contrib/cmake/external/snappy.cmake @@ -47,4 +47,4 @@ ExternalProject_Add(snappy ) # actually enables snappy in the source code -add_definitions(-DSNAPPY) \ No newline at end of file +add_definitions(-DTF_USE_SNAPPY) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 5ab84fec5bf..d198a796a7c 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1410,7 +1410,7 @@ cc_library( hdrs = LIB_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), defines = tf_additional_lib_defines() + [ - "SNAPPY", + "TF_USE_SNAPPY", ] + tf_additional_verbs_lib_defines() + tf_additional_mpi_lib_defines() + tf_additional_gdr_lib_defines(), diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc index 3b17bac8089..93a59348c8a 100644 --- a/tensorflow/core/platform/posix/port.cc +++ b/tensorflow/core/platform/posix/port.cc @@ -29,7 +29,7 @@ limitations under the License. #include #include #include -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY #include "snappy.h" #endif #if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__) @@ -126,7 +126,7 @@ void AdjustFilenameForLogging(string* filename) { } bool Snappy_Compress(const char* input, size_t length, string* output) { -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY output->resize(snappy::MaxCompressedLength(length)); size_t outlen; snappy::RawCompress(input, length, &(*output)[0], &outlen); @@ -139,7 +139,7 @@ bool Snappy_Compress(const char* input, size_t length, string* output) { bool Snappy_GetUncompressedLength(const char* input, size_t length, size_t* result) { -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY return snappy::GetUncompressedLength(input, length, result); #else return false; @@ -147,7 +147,7 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length, } bool Snappy_Uncompress(const char* input, size_t length, char* output) { -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY return snappy::RawUncompress(input, length, output); #else return false; diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index 85b53e07c43..e327d53949c 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include #include -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY #include "snappy.h" #endif @@ -118,7 +118,7 @@ void AdjustFilenameForLogging(string* filename) { } bool Snappy_Compress(const char* input, size_t length, string* output) { -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY output->resize(snappy::MaxCompressedLength(length)); size_t outlen; snappy::RawCompress(input, length, &(*output)[0], &outlen); @@ -131,7 +131,7 @@ bool Snappy_Compress(const char* input, size_t length, string* output) { bool Snappy_GetUncompressedLength(const char* input, size_t length, size_t* result) { -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY return snappy::GetUncompressedLength(input, length, result); #else return false; @@ -139,7 +139,7 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length, } bool Snappy_Uncompress(const char* input, size_t length, char* output) { -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY return snappy::RawUncompress(input, length, output); #else return false; From f2250bfe85b59c9fba128aad9993417eca711d75 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Oct 2017 16:24:56 -0700 Subject: [PATCH 10/41] Replace http://mirror.bazel.build with https://mirror.bazel.build PiperOrigin-RevId: 172816169 --- WORKSPACE | 2 +- tensorflow/contrib/cmake/external/cub.cmake | 2 +- tensorflow/contrib/cmake/external/gif.cmake | 2 +- tensorflow/contrib/cmake/external/jpeg.cmake | 2 +- tensorflow/contrib/cmake/external/lmdb.cmake | 2 +- .../contrib/makefile/download_dependencies.sh | 8 +- tensorflow/workspace.bzl | 90 +++++++++---------- 7 files changed, 54 insertions(+), 54 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 1bf1069f880..b40913801ba 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -5,7 +5,7 @@ http_archive( sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257", strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1", urls = [ - "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", "https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", # 2017-08-28 ], ) diff --git a/tensorflow/contrib/cmake/external/cub.cmake b/tensorflow/contrib/cmake/external/cub.cmake index d98579d2077..e03026b1b0a 100644 --- a/tensorflow/contrib/cmake/external/cub.cmake +++ b/tensorflow/contrib/cmake/external/cub.cmake @@ -14,7 +14,7 @@ # ============================================================================== include (ExternalProject) -set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip) +set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip) set(cub_HASH SHA256=b7ead9e291d34ffa8074243541c1380d63be63f88de23de8ee548db573b72ebe) set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) diff --git a/tensorflow/contrib/cmake/external/gif.cmake b/tensorflow/contrib/cmake/external/gif.cmake index 5cb719b8787..3d53c51fffc 100644 --- a/tensorflow/contrib/cmake/external/gif.cmake +++ b/tensorflow/contrib/cmake/external/gif.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive/giflib-5.1.4/) -set(gif_URL http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz) +set(gif_URL https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz) set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1) set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install) set(gif_BUILD ${CMAKE_BINARY_DIR}/gif/src/gif) diff --git a/tensorflow/contrib/cmake/external/jpeg.cmake b/tensorflow/contrib/cmake/external/jpeg.cmake index 058f554b8f2..d9a165e856c 100644 --- a/tensorflow/contrib/cmake/external/jpeg.cmake +++ b/tensorflow/contrib/cmake/external/jpeg.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(jpeg_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/jpeg_archive) -set(jpeg_URL http://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz) +set(jpeg_URL https://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz) set(jpeg_HASH SHA256=3a753ea48d917945dd54a2d97de388aa06ca2eb1066cbfdc6652036349fe05a7) set(jpeg_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jpeg/src/jpeg) set(jpeg_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/jpeg/install) diff --git a/tensorflow/contrib/cmake/external/lmdb.cmake b/tensorflow/contrib/cmake/external/lmdb.cmake index 28ec833babe..79971b7cfc3 100644 --- a/tensorflow/contrib/cmake/external/lmdb.cmake +++ b/tensorflow/contrib/cmake/external/lmdb.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(lmdb_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/lmdb) -set(lmdb_URL http://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz) +set(lmdb_URL https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz) set(lmdb_HASH SHA256=108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326) set(lmdb_BUILD ${CMAKE_BINARY_DIR}/lmdb/src/lmdb) set(lmdb_INSTALL ${CMAKE_BINARY_DIR}/lmdb/install) diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 39c89628d96..f0b9658e3d1 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -20,11 +20,11 @@ DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads BZL_FILE_PATH=tensorflow/workspace.bzl EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" -GEMMLOWP_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" -NSYNC_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -PROTOBUF_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -RE2_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 54559edbea2..a863aa18dd0 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -157,7 +157,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): mkl_repository( name = "mkl", urls = [ - "http://mirror.bazel.build/github.com/01org/mkl-dnn/releases/download/v0.9/mklml_lnx_2018.0.20170720.tgz", + "https://mirror.bazel.build/github.com/01org/mkl-dnn/releases/download/v0.9/mklml_lnx_2018.0.20170720.tgz", # "https://github.com/01org/mkl-dnn/releases/download/v0.9/mklml_lnx_2018.0.20170720.tgz", ], sha256 = "57ba56c4c243f403ff78f417ff854ef50b9eddf4a610a917b7c95e7fa8553a4b", @@ -174,7 +174,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "mkl_dnn", urls = [ "https://github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz", - "http://mirror.bazel.build/github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz", + "https://mirror.bazel.build/github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz", ], sha256 = "0d529ad4c49dc799e6df07c2b88b115d0668735da15fb3b3862d28d33fa68165", strip_prefix = "mkl-dnn-b01e3a55a07be62172e713bcd2644c5176360212", @@ -185,7 +185,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "eigen_archive", urls = [ "https://bitbucket.org/eigen/eigen/get/429aa5254200.tar.gz", - "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/429aa5254200.tar.gz", + "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/429aa5254200.tar.gz", ], sha256 = "61d8b6fc4279dd1dda986fb1677d15e3d641c07a3ea5abe255790b1f0c0c14e9", strip_prefix = "eigen-eigen-429aa5254200", @@ -198,7 +198,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "970285762565c7890c6c087d262b0a18286e7d0384f13a37786d8521773bc969", strip_prefix = "tools-0e906ebc527eab1cdbf7adabff5b474da9562e9f/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf", urls = [ - "http://mirror.bazel.build/github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", + "https://mirror.bazel.build/github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", # "https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", ], ) @@ -206,7 +206,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "libxsmm_archive", urls = [ - "http://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.8.1.tar.gz", + "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.8.1.tar.gz", # "https://github.com/hfp/libxsmm/archive/1.8.1.tar.gz", ], sha256 = "2ade869c3f42f23b5263c7d594aa3c7e5e61ac6a3afcaf5d6e42899d2a7986ce", @@ -222,7 +222,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "ortools_archive", urls = [ - "http://mirror.bazel.build/github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz", + "https://mirror.bazel.build/github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz", # "https://github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz", ], sha256 = "932075525642b04ac6f1b50589f1df5cd72ec2f448b721fd32234cf183f0e755", @@ -233,7 +233,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "com_googlesource_code_re2", urls = [ - "http://mirror.bazel.build/github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", + "https://mirror.bazel.build/github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", # "https://github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", ], sha256 = "bd63550101e056427c9e7ff12a408c1c8b74e9803f393ca916b2926fc2c4906f", @@ -243,7 +243,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "gemmlowp", urls = [ - "http://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip" + "https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip" # "https://github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", ], sha256 = "dd2557072bde12141419cb8320a9c25e6ec41a8ae53c2ac78c076a347bb46d9d", @@ -253,7 +253,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "farmhash_archive", urls = [ - "http://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", + "https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", # "https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", ], sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0", @@ -269,7 +269,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "highwayhash", urls = [ - "http://mirror.bazel.build/github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", + "https://mirror.bazel.build/github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", # "https://github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", ], sha256 = "0f30a15b1566d93f146c8d149878a06e91d9bb7ec2cfd76906df62a82be4aac9", @@ -280,7 +280,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "nasm", urls = [ - "http://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2", + "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2", "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.12.02.tar.bz2/d15843c3fb7db39af80571ee27ec6fad/nasm-2.12.02.tar.bz2", ], sha256 = "00b0891c678c065446ca59bcee64719d0096d54d6886e6e472aeee2e170ae324", @@ -291,7 +291,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "jpeg", urls = [ - "http://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", + "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", # "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", ], sha256 = "c15a9607892113946379ccea3ca8b85018301b200754f209453ab21674268e77", @@ -303,7 +303,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "png_archive", urls = [ - "http://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.2.53.tar.gz", + "https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.2.53.tar.gz", # "https://github.com/glennrp/libpng/archive/v1.2.53.tar.gz", ], sha256 = "716c59c7dfc808a4c368f8ada526932be72b2fcea11dd85dc9d88b1df1dfe9c2", @@ -314,7 +314,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "sqlite_archive", urls = [ - "http://mirror.bazel.build/www.sqlite.org/2017/sqlite-amalgamation-3200000.zip", + "https://mirror.bazel.build/www.sqlite.org/2017/sqlite-amalgamation-3200000.zip", "http://www.sqlite.org/2017/sqlite-amalgamation-3200000.zip", ], sha256 = "208780b3616f9de0aeb50822b7a8f5482f6515193859e91ed61637be6ad74fd4", @@ -325,7 +325,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "gif_archive", urls = [ - "http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz", + "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz", "http://pilotfiber.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz", ], sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1", @@ -336,7 +336,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "six_archive", urls = [ - "http://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", + "https://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", ], sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", @@ -347,7 +347,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "org_python_pypi_backports_weakref", urls = [ - "http://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", + "https://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", "https://pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", ], sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892", @@ -358,7 +358,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "com_github_andreif_codegen", urls = [ - "http://mirror.bazel.build/github.com/andreif/codegen/archive/1.0.tar.gz", + "https://mirror.bazel.build/github.com/andreif/codegen/archive/1.0.tar.gz", # "https://github.com/andreif/codegen/archive/1.0.tar.gz", ], sha256 = "2dadd04a2802de27e0fe5a19b76538f6da9d39ff244036afa00c1bba754de5ee", @@ -371,7 +371,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): licenses = ["notice"], # Python 2.0 sha256_urls = { "b5556e921715ddb9242c076cae3963f483aa47266c5e37ea4c187f77cc79501c": [ - "http://mirror.bazel.build/docs.python.org/2.7/_sources/license.txt", + "https://mirror.bazel.build/docs.python.org/2.7/_sources/license.txt", "https://docs.python.org/2.7/_sources/license.txt", ], }, @@ -387,7 +387,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): patched_http_archive( name = "protobuf_archive", urls = [ - "http://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", ], sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", @@ -410,7 +410,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "com_google_protobuf", urls = [ - "http://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", ], sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93", strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66", @@ -420,7 +420,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "com_google_protobuf_cc", urls = [ - "http://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", ], sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93", strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66", @@ -429,7 +429,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "nsync", urls = [ - "http://mirror.bazel.build/github.com/google/nsync/archive/ad722c76c6e6653f66be2e1f69521b7f7517da55.tar.gz", + "https://mirror.bazel.build/github.com/google/nsync/archive/ad722c76c6e6653f66be2e1f69521b7f7517da55.tar.gz", # "https://github.com/google/nsync/archive/ad722c76c6e6653f66be2e1f69521b7f7517da55.tar.gz", ], sha256 = "7dd8ca49319f77e8226cd020a9210a525f88ac26e7041c59c95418223a1cdf55", @@ -439,7 +439,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "com_google_googletest", urls = [ - "http://mirror.bazel.build/github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", + "https://mirror.bazel.build/github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", # "https://github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", ], sha256 = "9cbca84c4256bed17df2c8f4d00c912c19d247c11c9ba6647cd6dd5b5c996b8d", @@ -449,7 +449,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "com_github_gflags_gflags", urls = [ - "http://mirror.bazel.build/github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", + "https://mirror.bazel.build/github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", # "https://github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", ], sha256 = "4d222fab8f1ede4709cdff417d15a1336f862d7334a81abf76d09c15ecf9acd1", @@ -465,7 +465,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "pcre", sha256 = "ccdf7e788769838f8285b3ee672ed573358202305ee361cfec7a4a4fb005bbc7", urls = [ - "http://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", + "https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", "http://ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", ], strip_prefix = "pcre-8.39", @@ -476,7 +476,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "swig", sha256 = "58a475dbbd4a4d7075e5fe86d4e54c9edde39847cdb96a3053d87cb64a23a453", urls = [ - "http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", + "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", "http://ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", "http://pilotfiber.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", ], @@ -488,7 +488,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "curl", sha256 = "ff3e80c1ca6a068428726cd7dd19037a47cc538ce58ef61c59587191039b2ca6", urls = [ - "http://mirror.bazel.build/curl.haxx.se/download/curl-7.49.1.tar.gz", + "https://mirror.bazel.build/curl.haxx.se/download/curl-7.49.1.tar.gz", "https://curl.haxx.se/download/curl-7.49.1.tar.gz", ], strip_prefix = "curl-7.49.1", @@ -518,7 +518,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): patched_http_archive( name = "grpc", urls = [ - "http://mirror.bazel.build/github.com/grpc/grpc/archive/781fd6f6ea03645a520cd5c675da67ab61f87e4b.tar.gz", + "https://mirror.bazel.build/github.com/grpc/grpc/archive/781fd6f6ea03645a520cd5c675da67ab61f87e4b.tar.gz", # "https://github.com/grpc/grpc/archive/781fd6f6ea03645a520cd5c675da67ab61f87e4b.tar.gz", ], sha256 = "2004635e6a078acfac8ffa71738397796be4f8fb72f572cc44ecee5d99511d9f", @@ -542,7 +542,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "linenoise", sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7", urls = [ - "http://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", + "https://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", # "https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", ], strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3", @@ -554,7 +554,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "llvm", urls = [ - "http://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bb3c660e87f59abb665570a31b01ab125ec4c10e.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bb3c660e87f59abb665570a31b01ab125ec4c10e.tar.gz", "https://github.com/llvm-mirror/llvm/archive/bb3c660e87f59abb665570a31b01ab125ec4c10e.tar.gz", ], sha256 = "caab6d7978e6771cb4e9b5b89607c5370de8aa642913c6c14e892468194c94e4", @@ -566,7 +566,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "lmdb", urls = [ - "http://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", + "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", # "https://github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", ], sha256 = "108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326", @@ -577,7 +577,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "jsoncpp_git", urls = [ - "http://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", + "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", # "https://github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", ], sha256 = "07d34db40593d257324ec5fb9debc4dc33f29f8fb44e33a2eeb35503e61d0fe2", @@ -593,7 +593,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): patched_http_archive( name = "boringssl", urls = [ - "http://mirror.bazel.build/github.com/google/boringssl/archive/e3860009a091cd1bd2bc189cdbc3c6d095abde84.tar.gz", + "https://mirror.bazel.build/github.com/google/boringssl/archive/e3860009a091cd1bd2bc189cdbc3c6d095abde84.tar.gz", # "https://github.com/google/boringssl/archive/e3860009a091cd1bd2bc189cdbc3c6d095abde84.tar.gz", # 2017-07-07 ], sha256 = "02f5950f93c4fd3691771c07c9d04cf2999ab01383ff99da345249e93b0fcfb2", @@ -605,7 +605,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "zlib_archive", urls = [ - "http://mirror.bazel.build/zlib.net/zlib-1.2.8.tar.gz", + "https://mirror.bazel.build/zlib.net/zlib-1.2.8.tar.gz", "http://zlib.net/fossils/zlib-1.2.8.tar.gz", ], sha256 = "36658cb768a54c1d4dec43c3116c27ed893e88b02ecfcb44f2166f9c0b7f2a0d", @@ -621,7 +621,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "fft2d", urls = [ - "http://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", + "https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", "http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", ], sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296", @@ -631,7 +631,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "snappy", urls = [ - "http://mirror.bazel.build/github.com/google/snappy/archive/1.1.4.tar.gz", + "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.4.tar.gz", # "https://github.com/google/snappy/archive/1.1.4.tar.gz", ], sha256 = "2f7504c73d85bac842e893340333be8cb8561710642fc9562fccdd9d2c3fcc94", @@ -643,7 +643,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "nccl_archive", urls = [ - "http://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", + "https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", # "https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", ], sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176", @@ -668,7 +668,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "junit", jar_sha256 = "59721f0805e223d84b90677887d9ff567dc534d7c502ca903c0c2b17f05c116a", jar_urls = [ - "http://mirror.bazel.build/repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar", + "https://mirror.bazel.build/repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar", "http://repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar", "http://maven.ibiblio.org/maven2/junit/junit/4.12/junit-4.12.jar", ], @@ -681,7 +681,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "org_hamcrest_core", jar_sha256 = "66fdef91e9739348df7a096aa384a5685f4e875584cce89386a7a47251c4d8e9", jar_urls = [ - "http://mirror.bazel.build/repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", + "https://mirror.bazel.build/repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", "http://repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", "http://maven.ibiblio.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", ], @@ -692,7 +692,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "jemalloc", urls = [ - "http://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", + "https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", # "https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", ], sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8", @@ -704,7 +704,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "com_google_pprof", urls = [ - "http://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", + "https://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", # "https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", ], sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4", @@ -715,7 +715,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "cub_archive", urls = [ - "http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip", + "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip", # "https://github.com/NVlabs/cub/archive/1.7.3.zip", ], sha256 = "b7ead9e291d34ffa8074243541c1380d63be63f88de23de8ee548db573b72ebe", @@ -732,7 +732,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "cython", sha256 = "6dcd30b5ceb887b2b965ee7ceb82ea3acb5f0642fe2206c7636b45acea4798e5", urls = [ - "http://mirror.bazel.build/github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz", + "https://mirror.bazel.build/github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz", "https://github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz", ], strip_prefix = "cython-3732784c45cfb040a5b0936951d196f83a12ea17", @@ -742,7 +742,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "bazel_toolchains", urls = [ - "http://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/b2b4b38433bf2d1159360855ea4004378308711b.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/b2b4b38433bf2d1159360855ea4004378308711b.tar.gz", # "https://github.com/bazelbuild/bazel-toolchains/archive/b2b4b38433bf2d1159360855ea4004378308711b.tar.gz", ], sha256 = "46187270ca04ff8109980f45c3438fabfe48695e163789096eb82ee097ffe685", From e0e4f693978dcaf5bf4ecbc18e6926bdf33b2870 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Thu, 19 Oct 2017 16:28:57 -0700 Subject: [PATCH 11/41] [tf.contrib.seq2seq] Reserve -1s in GatherTree for error states. GatherTree now emits end_token after the first decoded end_token in the path, instead of -1s at the end of each sequence. PiperOrigin-RevId: 172816652 --- .../seq2seq/kernels/beam_search_ops.cc | 7 +++-- .../seq2seq/kernels/beam_search_ops_gpu.cu.cc | 25 +++++++++------ .../contrib/seq2seq/ops/beam_search_ops.cc | 11 ++++--- .../kernel_tests/beam_search_ops_test.py | 31 ++++++++++--------- 4 files changed, 45 insertions(+), 29 deletions(-) diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc index 95273e2b33e..64973ccccdc 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc @@ -112,7 +112,7 @@ struct GatherTree { const int32 max_time = parent_ids.dimension(0); const int32 batch_size = parent_ids.dimension(1); const int32 beam_width = parent_ids.dimension(2); - beams.setConstant(-1); + beams.setConstant(end_token); auto DoWork = [&, ctx, end_token](int start_batch_beam, int limit_batch_beam) { @@ -138,10 +138,13 @@ struct GatherTree { beams(level, batch, beam) = step_ids(level, batch, parent); parent = parent_ids(level, batch, parent); } + // Not necessary when using a BeamSearchDecoder, but necessary + // when a user feeds in possibly broken trajectory (i.e., non-eos + // entries in a beam following eos entries). bool finished = false; for (int32 time = 0; time < max_seq_len_b; ++time) { if (finished) { - beams(time, batch, beam) = -1; + beams(time, batch, beam) = end_token; } else if (beams(time, batch, beam) == end_token) { finished = true; } diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc index e71efc48cec..bc28d492fe1 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc @@ -46,24 +46,31 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time, const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam); beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix); int32 parent = ldg(parent_ids + initial_beam_ix); + bool found_bad = false; for (int32 level = max_seq_len_b - 2; level >= 0; --level) { const int32 level_beam_ix = GET_IX(level, beam); const int32 level_parent_ix = GET_IX(level, parent); if (parent < 0 || parent > beam_width) { beams[level_beam_ix] = -1; parent = -1; + found_bad = true; } else { beams[level_beam_ix] = ldg(step_ids + level_parent_ix); parent = ldg(parent_ids + level_parent_ix); } } - bool finished = false; - for (int32 time = 0; time < max_seq_len_b; ++time) { - const int32 level_beam_ix = GET_IX(time, beam); - if (finished) { - beams[level_beam_ix] = -1; - } else if (beams[level_beam_ix] == end_token) { - finished = true; + // Not necessary when using a BeamSearchDecoder, but necessary + // when a user feeds in possibly broken trajectory (i.e., non-eos + // entries in a beam following eos entries). + if (!found_bad) { + bool finished = false; + for (int32 time = 0; time < max_seq_len_b; ++time) { + const int32 level_beam_ix = GET_IX(time, beam); + if (finished) { + beams[level_beam_ix] = end_token; + } else if (beams[level_beam_ix] == end_token) { + finished = true; + } } } #undef GET_IX @@ -80,8 +87,8 @@ struct GatherTree { const int32 max_time = parent_ids.dimension(0); const int32 batch_size = parent_ids.dimension(1); const int32 beam_width = parent_ids.dimension(2); - // First kernel launch to zero things out - beams.device(d) = beams.constant(T(-1)); + // First kernel launch to "zero" things out + beams.device(d) = beams.constant(end_token); CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d); // clang-format off diff --git a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc index 231504bfbb3..71539b6f592 100644 --- a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc @@ -53,11 +53,14 @@ REGISTER_OP("GatherTree") .Doc(R"doc( Calculates the full beams from the per-step ids and parent beam ids. -This op implements the following mathematical equations: +On CPU, if an out of bound parent id is found, an error is returned. +On GPU, if an out of bound parent id is found, a -1 is stored in the +corresponding output value and the execution for that beam returns early. -```python -TODO(ebrevdo): fill in -``` +For a given beam, past the time step containing the first decoded `end_token` +all values are filled in with `end_token`. + +TODO(ebrevdo): fill in the remainder of this docstring. step_ids: `[max_time, batch_size, beam_width]`. parent_ids: `[max_time, batch_size, beam_width]`. diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py index f3013148720..277c5b6ef76 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py @@ -36,24 +36,26 @@ class GatherTreeTest(test.TestCase): def testGatherTreeOne(self): # (max_time = 4, batch_size = 1, beams = 3) + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]]) max_sequence_lengths = [3] - expected_result = _transpose_batch_time( - [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + expected_result = _transpose_batch_time([[[2, 2, 2], [6, 5, 6], [7, 8, 9], + [10, 10, 10]]]) beams = beam_search_ops.gather_tree( step_ids=step_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, - end_token=10) + end_token=end_token) with self.test_session(use_gpu=True): self.assertAllEqual(expected_result, beams.eval()) def testBadParentValuesOnCPU(self): # (batch_size = 1, max_time = 4, beams = 3) # bad parent in beam 1 time 1 + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( @@ -64,7 +66,7 @@ class GatherTreeTest(test.TestCase): step_ids=step_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, - end_token=10) + end_token=end_token) with self.test_session(): with self.assertRaisesOpError( r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"): @@ -77,19 +79,20 @@ class GatherTreeTest(test.TestCase): return # (max_time = 4, batch_size = 1, beams = 3) # bad parent in beam 1 time 1; appears as a negative index at time 0 + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) max_sequence_lengths = [3] - expected_result = _transpose_batch_time( - [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + expected_result = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9], + [10, 10, 10]]]) with ops.device("/device:GPU:0"): beams = beam_search_ops.gather_tree( step_ids=step_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, - end_token=10) + end_token=end_token) with self.test_session(use_gpu=True): self.assertAllEqual(expected_result, beams.eval()) @@ -115,24 +118,24 @@ class GatherTreeTest(test.TestCase): self.assertEqual((max_time, batch_size, beam_width), beams.shape) beams_value = beams.eval() for b in range(batch_size): - # Past max_sequence_lengths[b], we emit all -1s. + # Past max_sequence_lengths[b], we emit all end tokens. b_value = beams_value[max_sequence_lengths[b]:, b, :] - self.assertAllClose(b_value, -1. * np.ones_like(b_value)) + self.assertAllClose(b_value, end_token * np.ones_like(b_value)) for batch, beam in itertools.product( range(batch_size), range(beam_width)): v = np.squeeze(beams_value[:, batch, beam]) if end_token in v: + found_bad = np.where(v == -1)[0] + self.assertEqual(0, len(found_bad)) found = np.where(v == end_token)[0] - # Should be up to 1 instance of end_token per beam. - self.assertEqual(len(found), 1) - found = found[0] + found = found[0] # First occurrence of end_token. # If an end_token is found, everything before it should be a # valid id and everything after it should be -1. if found > 0: self.assertAllEqual( v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool)) - self.assertAllClose( - v[found + 1:], -1 * np.ones_like(v[found + 1:])) + self.assertAllClose(v[found + 1:], + end_token * np.ones_like(v[found + 1:])) if __name__ == "__main__": From 2977dccc96c343ca85cb00b50672b36c99656532 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Thu, 19 Oct 2017 16:42:14 -0700 Subject: [PATCH 12/41] Context-specific C API to set options other than configproto (still unused) PiperOrigin-RevId: 172818175 --- tensorflow/c/eager/c_api.cc | 17 ++++++++++---- tensorflow/c/eager/c_api.h | 20 ++++++++++++++-- tensorflow/c/eager/c_api_internal.h | 4 ++++ tensorflow/c/eager/c_api_test.cc | 36 ++++++++++++++--------------- tensorflow/python/eager/context.py | 16 ++++++++----- tensorflow/python/pywrap_tfe.i | 16 ++++++++++++- 6 files changed, 78 insertions(+), 31 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 514a4010bc8..334c02bff9a 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -54,9 +54,18 @@ string DeviceName(tensorflow::Device* d) { extern "C" { -TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) { +TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; } + +void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, + size_t proto_len, TF_Status* status) { + TF_SetConfig(&options->session_options, proto, proto_len, status); +} + +void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } + +TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { TF_Graph* graph = TF_NewGraph(); - TF_Session* session = TF_NewSession(graph, opts, status); + TF_Session* session = TF_NewSession(graph, &opts->session_options, status); if (status->status.ok()) { if (session->device_mgr == nullptr || session->devices.empty()) { status->status = tensorflow::errors::InvalidArgument( @@ -72,8 +81,8 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) { TFE_Context* ret = new TFE_Context(session); ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime( - ret->session->device_mgr, opts->options.env, TF_GRAPH_DEF_VERSION, - &ret->func_lib_def, {})); + ret->session->device_mgr, opts->session_options.options.env, + TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {})); ret->rendezvous = new tensorflow::IntraProcessRendezvous(ret->session->device_mgr); diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 9bfa63711b5..201cb222c92 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -43,14 +43,30 @@ limitations under the License. extern "C" { #endif +typedef struct TFE_ContextOptions TFE_ContextOptions; + +// Return a new options object. +TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(); + +// Set the config in TF_ContextOptions.options. +// config should be a serialized tensorflow.ConfigProto proto. +// If config was not parsed successfully as a ConfigProto, record the +// error information in *status. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status); + +// Destroy an options object. +TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); + // "Context" under which operations/functions are executed. It encapsulates // things like the available devices, resource manager etc. // // TODO(ashankar): Merge with TF_Session? typedef struct TFE_Context TFE_Context; -TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, - TF_Status* status); +TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext( + const TFE_ContextOptions* opts, TF_Status* status); TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status); TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 712526f1700..7a440a5a7e8 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -35,6 +35,10 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" +struct TFE_ContextOptions { + TF_SessionOptions session_options; +}; + struct TFE_Context { explicit TFE_Context(TF_Session* s) : session(s) {} diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 72e0fe8a156..5344956ee77 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -62,10 +62,10 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { void BM_InitOp(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); tensorflow::testing::StartTiming(); @@ -84,10 +84,10 @@ BENCHMARK(BM_InitOp); void BM_Execute(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); @@ -109,9 +109,9 @@ BENCHMARK(BM_Execute); TEST(CAPI, Context) { TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TF_DeviceList* devices = TFE_ContextListDevices(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -150,9 +150,9 @@ TEST(CAPI, TensorHandle) { TEST(CAPI, TensorHandleCopyBetweenDevices) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status.get()); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); @@ -218,10 +218,10 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) { TEST(CAPI, Execute) { TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); @@ -285,10 +285,10 @@ string MatMulFunction() { TEST(CAPI, FunctionDefAndExecute) { TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); string function_def = MatMulFunction(); TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), @@ -326,10 +326,10 @@ TEST(CAPI, FunctionDefAndExecute) { void BM_ExecuteFunction(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); string function_def = MatMulFunction(); TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), @@ -406,10 +406,10 @@ TEST(CAPI, Variables) { // Variables use resource handles, so this is really a test for resource // tensor handling. TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -446,10 +446,10 @@ TEST(CAPI, Variables) { void BM_ReadVariable(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index aa7cba56def..58581283d27 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -26,7 +26,6 @@ import threading from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import device as pydev from tensorflow.python.framework import errors -from tensorflow.python.util import compat from tensorflow.python.util import tf_contextlib GRAPH_MODE = 0 @@ -103,11 +102,16 @@ class Context(object): if self._context_handle is not None: return assert self._context_devices is None - opts = pywrap_tensorflow.TF_NewSessionOptions( - target=compat.as_bytes(""), config=self._config) - with errors.raise_exception_on_not_ok_status() as status: - self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status) - pywrap_tensorflow.TF_DeleteSessionOptions(opts) + opts = pywrap_tensorflow.TFE_NewContextOptions() + try: + with errors.raise_exception_on_not_ok_status() as status: + if self._config is not None: + config_str = self._config.SerializeToString() + pywrap_tensorflow.TFE_ContextOptionsSetConfig( + opts, config_str, len(config_str), status) + self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status) + finally: + pywrap_tensorflow.TFE_DeleteContextOptions(opts) # Store list of devices self._context_devices = [] with errors.raise_exception_on_not_ok_status() as status: diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 5c624a9c126..36c09c20c21 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -30,12 +30,25 @@ limitations under the License. %rename("%s") TFE_Py_TapeDeleteTrace; %rename("%s") TFE_Py_TapeRecordOperation; %rename("%s") TFE_Py_TapeExport; - +%rename("%s") TFE_NewContextOptions; +%rename("%s") TFE_ContextOptionsSetConfig; +%rename("%s") TFE_DeleteContextOptions; %{ #include "tensorflow/python/eager/pywrap_tfe.h" %} +%typemap(in) (const void* proto) { + char* c_string; + Py_ssize_t py_size; + // PyBytes_AsStringAndSize() does not copy but simply interprets the input + if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) { + // Python has raised an error (likely TypeError or UnicodeEncodeError). + SWIG_fail; + } + $1 = static_cast(c_string); +} + %typemap(out) TF_DataType { $result = PyInt_FromLong($1); } @@ -165,3 +178,4 @@ limitations under the License. %typemap(in, numinputs=0) TF_Status *out_status; %typemap(freearg) (TF_Status* out_status); %typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status); +%typemap(in) (const void* proto); From e885d1abdce5db4a67e0b3ba85dbcc708f856645 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Thu, 19 Oct 2017 16:42:45 -0700 Subject: [PATCH 13/41] One less error message in gradients_function PiperOrigin-RevId: 172818233 --- tensorflow/python/eager/backprop.py | 11 ++++------- tensorflow/python/eager/backprop_test.py | 8 ++++++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index da17be05b7d..9580e848475 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -396,12 +396,11 @@ def implicit_grad(f): return grad_fn -def _get_arg_spec(f, params): +def _get_arg_spec(f, params, param_args): args = tf_inspect.getargspec(f).args if params is None: if not args: - raise ValueError("When params is None the differentiated function cannot" - " only take arguments by *args and **kwds.") + return range(len(param_args)) return range(len(args)) elif all(isinstance(x, six.string_types) for x in params): return [args.index(n) for n in params] @@ -560,10 +559,9 @@ def val_and_grad_function(f, params=None): ValueError: if the params are not all strings or all integers. """ - parameter_positions = _get_arg_spec(f, params) - def decorated(*args, **kwds): """Computes the value and gradient of the decorated function.""" + parameter_positions = _get_arg_spec(f, params, args) dy = kwds.pop("dy", None) if dy is not None: dy = ops.convert_to_tensor(dy) @@ -616,10 +614,9 @@ def make_vjp(f, params=None): """ - parameter_positions = _get_arg_spec(f, params) - def decorated(*args, **kwds): """Computes the value and gradient of the decorated function.""" + parameter_positions = _get_arg_spec(f, params, args) assert not kwds, "The gradient function can't take keyword arguments." tape.push_new_tape() sources = [] diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 95d5f0adcb4..7da8eb0c9b5 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -381,6 +381,14 @@ class BackpropTest(test.TestCase): [tensor_shape.TensorShape(s).as_proto() for s in shape_list], backprop.make_attr([pywrap_tensorflow.TF_ATTR_SHAPE], shape_list)) + def testArgsGradientFunction(self): + + def f(*args): + return args[0] * args[0] + + grad = backprop.gradients_function(f) + self.assertAllEqual(grad(1.0)[0], 2.0) + def testMultiValueConvertToTensor(self): x = resource_variable_ops.ResourceVariable( initial_value=array_ops.constant([1.0]), name='x') From f1054553eafc74df8be9425c3344e71af98962ad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Oct 2017 17:31:49 -0700 Subject: [PATCH 14/41] Add missing backslash in macro in mkl_transpose_op.cc. Fix erroneous formatting that resulted from it. Fix return type for function template MKLTranspose2D. Define MKL_Complex8 and MKL_Complex16 macros before including the MKL headers. Only conjugate but don't transpose if conjugate=true && perm[0] == 0 && perm[1] == 1. PiperOrigin-RevId: 172824073 --- tensorflow/core/kernels/mkl_transpose_op.cc | 118 ++++++++++---------- 1 file changed, 62 insertions(+), 56 deletions(-) diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc index 89a1d5e8a7d..764d4c9400e 100644 --- a/tensorflow/core/kernels/mkl_transpose_op.cc +++ b/tensorflow/core/kernels/mkl_transpose_op.cc @@ -18,6 +18,9 @@ limitations under the License. #ifdef INTEL_MKL #define EIGEN_USE_THREADS +#include "tensorflow/core/framework/numeric_types.h" +#define MKL_Complex8 tensorflow::complex64 +#define MKL_Complex16 tensorflow::complex128 #include "mkl_trans.h" #include "tensorflow/core/kernels/transpose_functor.h" #include "tensorflow/core/kernels/transpose_op.h" @@ -41,7 +44,7 @@ namespace tensorflow { namespace { template -void MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) {} +Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out); // Documentation here: https://software.intel.com/en-us/node/520863 // Parameters: (ordering:row-major, operation:transpose, num_rows, num_cols, @@ -54,70 +57,73 @@ void MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) {} mkl_##PREFIX##omatcopy('R', trans, in.dim_size(0), in.dim_size(1), 1, \ in.flat().data(), in.dim_size(1), \ out->flat().data(), in.dim_size(0)); \ - return Status::OK(); + return Status::OK(); \ } - INSTANTIATE(float, s) - INSTANTIATE(double, d) - INSTANTIATE(complex64, c) - INSTANTIATE(complex128, z) +INSTANTIATE(float, s) +INSTANTIATE(double, d) +INSTANTIATE(complex64, c) +INSTANTIATE(complex128, z) #undef INSTANTIATE - static const char kMKLTranspose = 'T'; - static const char kMKLConjugateTranspose = 'C'; +static const char kMKLTranspose = 'T'; +static const char kMKLConjugateTranspose = 'C'; - } // namespace tensorflow +} // namespace - Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, - gtl::ArraySlice perm, - Tensor* out) { - if (in.dims() == 2) { - switch (in.dtype()) { - case DT_FLOAT: - return MKLTranspose2D(kMKLTranspose, in, out); - case DT_DOUBLE: - return MKLTranspose2D(kMKLTranspose, in, out); - case DT_COMPLEX64: - return MKLTranspose2D(kMKLTranspose, in, out); - case DT_COMPLEX128: - return MKLTranspose2D(kMKLTranspose, in, out); - default: - break; - } +Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, + gtl::ArraySlice perm, + Tensor* out) { + if (in.dims() == 2) { + if (perm[0] == 0 && perm[1] == 1) { + return Status::OK(); } - // Fallback to eigen if transpose parameters not supported by MKL - typedef Eigen::ThreadPoolDevice CPUDevice; - return ::tensorflow::DoTranspose(ctx->eigen_device(), in, perm, - out); - } - - Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx, - const Tensor& in, - gtl::ArraySlice perm, - Tensor* out) { - if (in.dims() == 2) { - // TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels - // for any transpose that can be reduced to swapping the last two - // dimensions in a rank-3 tensor. We can even run each outer dimension in - // a separate thread. - switch (in.dtype()) { - case DT_FLOAT: - return MKLTranspose2D(kMKLTranspose, in, out); - case DT_DOUBLE: - return MKLTranspose2D(kMKLTranspose, in, out); - case DT_COMPLEX64: - return MKLTranspose2D(kMKLConjugateTranspose, in, out); - case DT_COMPLEX128: - return MKLTranspose2D(kMKLConjugateTranspose, in, out); - default: - break; - } + switch (in.dtype()) { + case DT_FLOAT: + return MKLTranspose2D(kMKLTranspose, in, out); + case DT_DOUBLE: + return MKLTranspose2D(kMKLTranspose, in, out); + case DT_COMPLEX64: + return MKLTranspose2D(kMKLTranspose, in, out); + case DT_COMPLEX128: + return MKLTranspose2D(kMKLTranspose, in, out); + default: + break; } - // Fallback to eigen if transpose parameters not supported by MKL - typedef Eigen::ThreadPoolDevice CPUDevice; - return ::tensorflow::DoConjugateTranspose(ctx->eigen_device(), - in, perm, out); } + // Fallback to eigen if transpose parameters not supported by MKL + typedef Eigen::ThreadPoolDevice CPUDevice; + return ::tensorflow::DoTranspose(ctx->eigen_device(), in, perm, + out); +} + +Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx, + const Tensor& in, + gtl::ArraySlice perm, + Tensor* out) { + if (in.dims() == 2 && perm[0] == 1 && perm[1] == 0) { + // TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels + // for any transpose that can be reduced to swapping the last two + // dimensions in a rank-3 tensor. We can even run each outer dimension in + // a separate thread. + switch (in.dtype()) { + case DT_FLOAT: + return MKLTranspose2D(kMKLTranspose, in, out); + case DT_DOUBLE: + return MKLTranspose2D(kMKLTranspose, in, out); + case DT_COMPLEX64: + return MKLTranspose2D(kMKLConjugateTranspose, in, out); + case DT_COMPLEX128: + return MKLTranspose2D(kMKLConjugateTranspose, in, out); + default: + break; + } + } + // Fallback to eigen if transpose parameters not supported by MKL + typedef Eigen::ThreadPoolDevice CPUDevice; + return ::tensorflow::DoConjugateTranspose(ctx->eigen_device(), in, + perm, out); +} } // namespace tensorflow From e7654b99c46a479d61c1fd96a9f4710682acf4da Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Thu, 19 Oct 2017 17:44:32 -0700 Subject: [PATCH 15/41] Adds tfe.IsolateTest, an Eager-agnostic abstraction for isolating resources Switches Eager unit tests to use IsolateTest, so their resources have unique container names. PiperOrigin-RevId: 172825317 --- tensorflow/contrib/eager/python/tfe.py | 2 + tensorflow/python/eager/graph_callable.py | 10 +++ tensorflow/python/framework/test_util.py | 64 ++++++++++++++++- tensorflow/python/framework/test_util_test.py | 71 +++++++++++++++++++ .../resource_variable_ops_test.py | 11 ++- .../python/ops/resource_variable_ops.py | 11 +++ 6 files changed, 164 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 25942aadfbb..4ed258f6ffb 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -53,6 +53,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@in_eager_mode @@in_graph_mode +@@IsolateTest @@run_test_in_graph_and_eager_modes """ @@ -84,6 +85,7 @@ from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run +from tensorflow.python.framework.test_util import IsolateTest from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index 3aba164630d..0ec83636a0f 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -312,11 +312,21 @@ def _graph_callable_internal(func, shape_and_dtypes): Returns: Callable graph object. """ + container = tf_ops.get_default_graph()._container # pylint: disable=protected-access + container_prefix = tf_ops.get_default_graph()._container_prefix # pylint: disable=protected-access with context.graph_mode(): # This graph will store both the initialization and the call version of the # wrapped function. It will later be used by the backprop code to build the # backprop graph, if necessary. tmp_graph = tf_ops.Graph() + # Inherit the container from the original graph to create resources at user + # expected containers. Also inherits the container prefix, since this is + # used for error checking when isolating Eager execution (the container + # prefix at creation must match the container prefix when used, and + # variables returned from the graph callable will be used in the outside + # context). + tmp_graph._container = container # pylint: disable=protected-access + tmp_graph._container_prefix = container_prefix # pylint: disable=protected-access with tmp_graph.as_default(): # Placeholders for the non-variable inputs. func_inputs = _get_graph_callable_inputs(shape_and_dtypes) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index c681ffb514c..a01bf02deb4 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -47,6 +47,7 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import device_lib from tensorflow.python.client import session from tensorflow.python.eager import context +from tensorflow.python.eager import tape from tensorflow.python.framework import device as pydev from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -391,6 +392,66 @@ def with_c_api(cls): return cls +class IsolateTest(object): + """A context manager which isolates resources in its block. + + Provides an Eager-agnostic abstraction for preventing the sharing of + variables and other resources. + + In graph mode, resource handle ops are only executed in a particular Session, + isolating them from resources with the same name in other Graphs. In Eager, + separate Sessions do not exist, so resources (particularly ResourceVariables) + would be shared implicitly if a resource of the same name were created + anywhere in a Python process. Multiple handles to the same resource would + cause several issues, and so this type of sharing will raise an exception. + + Using resources with the same name in a single Python process may be useful + (especially for unit tests), so this context manager provides an abstraction + for isolating resources. Using a resource created in one Isolation environment + in another is an error. + + Example usage in Eager mode: + + ```python + import tensorflow as tf + # Import subject to change + from tensorflow.contrib.eager.python import tfe + + tfe.enable_eager_execution() + + for hyperparameter in [1, 2, 3]: + with tfe.IsolateTest(): + v = tfe.Variable(name="v", initial_value=hyperparameter) + # train model, test results ... + ``` + + IsolateTest is currently exposed through contrib.eager, but it creates a new + default Graph and provides equivalent safety in graph mode. + """ + + def __init__(self): + if context.in_eager_mode() and tape.could_possibly_record(): + raise ValueError("Cannot isolate Eager execution with an active tape.") + # In Eager, Graphs set a container which isolates resources, and maintain a + # VariableStore which caches ResourceVariable objects created through + # get_variable. So setting the default Graph has the side effect of + # isolating Eager resources. + with context.eager_mode(): + # Create the graph in Eager mode, as this provides stricter semantics + # (i.e. has a unique container prefix). This prevents implicit sharing + # when a Graph-mode graph is created and then Eager mode is enabled (an + # error through enable_eager_execution, but common with context managers + # in unit tests). + self._graph_as_default_context_manager = ops.Graph().as_default() + + def __enter__(self): + self._graph_as_default_context_manager.__enter__() + + def __exit__(self, type_arg, value_arg, traceback_arg): + return self._graph_as_default_context_manager.__exit__( + type_arg, value_arg, traceback_arg) + + def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None, use_gpu=False, force_gpu=False, reset_test=True): @@ -440,9 +501,8 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None, with context.device("/device:CPU:0"): f(self, **kwargs) - eager_graph = graph or ops.Graph() with context.eager_mode(): - with eager_graph.as_default(): + with IsolateTest(): run_eager_mode() return decorated diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 6129fa2e0d0..b2f8d62095f 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -27,12 +27,16 @@ from google.protobuf import text_format from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -325,5 +329,72 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertEqual(a_rand, b_rand) +@test_util.with_c_api +class IsolationTest(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_variable_reuse_exception(self): + with test_util.IsolateTest(), session.Session(): + first_container_variable = resource_variable_ops.ResourceVariable( + name="first_container_variable", + initial_value=1) + if context.in_graph_mode(): + self.evaluate([variables.global_variables_initializer()]) + with test_util.IsolateTest(): + if context.in_graph_mode(): + with self.assertRaises(RuntimeError): + self.evaluate(first_container_variable.read_value()) + else: + with self.assertRaises(ValueError): + first_container_variable.read_value() + + @test_util.run_in_graph_and_eager_modes() + def test_variable_reuse_exception_nested(self): + with test_util.IsolateTest(), session.Session(): + first_container_variable = resource_variable_ops.ResourceVariable( + name="first_container_variable", + initial_value=1) + if context.in_graph_mode(): + self.evaluate([variables.global_variables_initializer()]) + with test_util.IsolateTest(), session.Session(): + if context.in_graph_mode(): + with self.assertRaises(RuntimeError): + self.evaluate(first_container_variable.read_value()) + else: + with self.assertRaises(ValueError): + first_container_variable.read_value() + + @test_util.run_in_graph_and_eager_modes() + def test_no_sharing(self): + with test_util.IsolateTest(), session.Session(): + first_container_variable = resource_variable_ops.ResourceVariable( + name="same_name", + initial_value=1) + if context.in_graph_mode(): + self.evaluate([variables.global_variables_initializer()]) + with test_util.IsolateTest(), session.Session(): + second_container_variable = resource_variable_ops.ResourceVariable( + name="same_name", + initial_value=2) + if context.in_graph_mode(): + self.evaluate([variables.global_variables_initializer()]) + self.assertEqual( + 2, self.evaluate(second_container_variable.read_value())) + self.assertEqual(1, self.evaluate(first_container_variable.read_value())) + + def test_graph_mode_isolation(self): + with context.graph_mode(): + # Even if we've (accidentally) called IsolateTest in Graph mode, it should + # provide Eager isolation. + with test_util.IsolateTest(): + with context.eager_mode(): + first_container_variable = resource_variable_ops.ResourceVariable( + name="first_container_variable", + initial_value=1) + with context.eager_mode(): + with self.assertRaises(ValueError): + first_container_variable.read_value() + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 23676223dc6..cf4b61674fc 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -309,12 +309,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(variables.global_variables_initializer()) w = resource_variable_ops.var_handle_op( - dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4") + dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4", + # Needed in Eager since we get a unique container name by default. + container=ops.get_default_graph()._container) w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype) self.assertEqual(300.0, self.evaluate(w_read)) x = resource_variable_ops.var_handle_op( - dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5") + dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5", + container=ops.get_default_graph()._container) with self.assertRaisesOpError("Resource .*/var5/.* does not exist"): x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype) self.evaluate(x_read) @@ -328,7 +331,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(variables.global_variables_initializer()) w = resource_variable_ops.var_handle_op( - dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var6") + dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var6", + # Needed in Eager since we get a unique container name by default. + container=ops.get_default_graph()._container) w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype) self.assertEqual(300.0, self.evaluate(w_read)) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index dd3f167145a..aa45752a9d4 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -270,6 +270,9 @@ class ResourceVariable(variables.Variable): collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] self._save_slice_info = None self._in_graph_mode = context.in_graph_mode() + # Save the graph's container prefix for error checking. Reading the value of + # the ResourceVariable from another Graph in Eager mode is an error. + self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access with ops.control_dependencies(None): with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name: @@ -577,7 +580,15 @@ class ResourceVariable(variables.Variable): Returns: the read operation. + Raises: + ValueError: if the ResourceVariable was created in another isolation + environment or graph. """ + if (not self._in_graph_mode and + self._container_prefix != ops.get_default_graph()._container_prefix): # pylint: disable=protected-access + raise ValueError( + "Attempted to read a variable from another isolation environment" + " or Graph") with ops.name_scope("Read"): # Ensure we read the variable in the same device as the handle. with ops.device(self._handle_device): From db07ee27b75f5efecf3f3706ec1a11e4cd05da54 Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Thu, 19 Oct 2017 17:58:38 -0700 Subject: [PATCH 16/41] Fix bug introduced in https://github.com/tensorflow/tensorflow/commit/dc442f4ce2d3b11b56721337fe2b9e2282be93be Potentially invalid pointers passed to GraphConstructor::Construct() PiperOrigin-RevId: 172826567 --- tensorflow/core/graph/graph_constructor.cc | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 92b48432210..b2c193b050b 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -1068,10 +1068,16 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, refiner->set_graph_def_version( std::min(refiner->graph_def_version(), gdef.versions().producer())); - return GraphConstructor::Construct( - opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner, - &results->return_tensors, &results->return_nodes, - &results->unused_input_map_keys); + if (results == nullptr) { + return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(), + &gdef.library(), g, refiner, nullptr, + nullptr, nullptr); + } else { + return GraphConstructor::Construct( + opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner, + &results->return_tensors, &results->return_nodes, + &results->unused_input_map_keys); + } } void CopyGraph(const Graph& src, Graph* dest) { From fa4d04ab99d45eb317e39c1a6b8848bbc47ebe0e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Oct 2017 17:58:50 -0700 Subject: [PATCH 17/41] Address Metrics TODOs, in particular we'd like them to work in Graph mode. To get this to work, we add support for capturing tensors from outside the function graph in graph mode to eager/function.py. Also get unique names and variable scopes working. PiperOrigin-RevId: 172826589 --- tensorflow/contrib/eager/python/BUILD | 9 +- .../contrib/eager/python/evaluator_test.py | 2 +- .../contrib/eager/python/metrics_impl.py | 156 ++++++++++++------ .../contrib/eager/python/metrics_test.py | 51 ++++++ tensorflow/python/eager/function.py | 61 +++++-- tensorflow/python/eager/function_test.py | 18 ++ tensorflow/tools/ci_build/ci_sanity.sh | 1 + 7 files changed, 234 insertions(+), 64 deletions(-) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 0c61630aa8f..702136e3e4f 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -132,11 +132,12 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", "//tensorflow/python:layers_base", "//tensorflow/python:math_ops", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", ], ) @@ -146,6 +147,10 @@ py_test( srcs_version = "PY2AND3", deps = [ ":metrics", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], ) @@ -160,6 +165,8 @@ py_library( deps = [ ":datasets", ":metrics", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", ], ) diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py index 099e10e2307..b18463c31a7 100644 --- a/tensorflow/contrib/eager/python/evaluator_test.py +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -86,7 +86,7 @@ class EvaluatorTest(test.TestCase): for v in e.metric_variables: p = v.name.split("/")[0] prefix_count[p] = prefix_count.get(p, 0) + 1 - self.assertEqual({"outer-mean": 2, "mean": 2}, prefix_count) + self.assertEqual({"outer_mean": 2, "mean": 2}, prefix_count) def testDataset(self): e = SimpleEvaluator(IdentityModel()) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 63a0f8d9a45..2a624b218cc 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -18,6 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re + +from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -25,55 +29,69 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope +_to_replace = re.compile("[^A-Za-z0-9.]") + + class Metric(object): """A metric holds state for aggregating statistics over an evaluation run. Users will use Evaluator.add_metric() to add Metric objects to their - evaluation, call them in each step, and then use - Evaluator.all_metric_results() at the end. + evaluation, call them in each step (treating the object as a callable), + and then use Evaluator.all_metric_results() at the end. Descendants will implement: - * call(): Should follow this pattern: - if not self.built: - self.var = self.add_variable(...) - self.add_update(self.var.assign_add(...)) - * aggregate(): Adds in the state from a list of metrics of the same type - as `self`. (Default of summing all the variables will be fine for most - descendants.) - * result(): Computes and returns a final value for the metric + * `build()`: All variables should be created in this method, by calling + `self.add_variable()` as in: `self.var = self.add_variable(...)` + build() will be called in the first invocation of `__call__()`, with + the same arguments passed `call()`. + * `call()`: Has all updates to variables, as in: + self.var.assign_add(...) + * `result()`: Computes and returns a final value for the metric from the variables in `self`. + + Decendants may override, but usually won't need to: + * `aggregate()`: Adds in the state from a list of metrics of the same type + as `self`. (Default is to sum all the variables.) + * `reset()`: Reset all variables to their initial state. (Default is to + zero all the variables.) + Note that users should not call `aggregate()` or `reset()`, they are for + use by TensorFlow infrastructure. """ def __init__(self, name=None): - self.built = False + self._built = False self._vars = [] self._updates = [] - self._name = name or self.__class__.__name__ - # TODO(josh11b): Need some way to make sure two Metrics in the same - # Network have distinct names. Maybe we can get a unique name from - # a name/variable scope? - # TODO(josh11b): self._in_graph_mode = context.in_graph_mode() + name = name or self.__class__.__name__ + # Replace things like spaces in name to create a valid scope name. + scope_name = _to_replace.sub("_", name) + # We create the variable scope now to get the unique name that will + # be used as a variable prefix when build() calls add_variable(). + with variable_scope.variable_scope( + None, default_name=scope_name, use_resource=True, reuse=False) as scope: + pos = scope.name.rfind(scope_name) + self._name = name + scope.name[pos + len(scope_name):] + self._scope = scope + if context.in_graph_mode(): + # We make self.call() into a graph callable here, so that we can + # return a single op that performs all of the variable updates. + self.call = function.defun(self.call) # ---- API for users ---- def __call__(self, *args, **kwargs): - # TODO(josh11b): If self._in_graph_mode is true, make self.call() into a - # graph callable here, so that variable updates happen without requiring - # a separate fetch. - # TODO(josh11b): Do we need a separate build() method to separate - # initialization from each update? If so, how do we get the arguments - # to it? We *could* just pass in *args and **kwargs... - if not self.built: - # TODO(ashankar): Set up container isolation so there is no chance - # distinct metrics objects accidentally share variables. - # TODO(josh11b): Replace things like spaces in self._name to create - # a valid scope name. - with variable_scope.variable_scope( - self._name, use_resource=True, reuse=False): - ret = self.call(*args, **kwargs) - self.built = True - else: - ret = self.call(*args, **kwargs) - return ret + """Returns op to execute to update this metric for these inputs. + + Returns None if eager execution is enabled. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric, passed on to `call()`. + """ + if not self._built: + with variable_scope.variable_scope(self._scope): + self.build(*args, **kwargs) + self._built = True + return self.call(*args, **kwargs) @property def name(self): @@ -84,10 +102,43 @@ class Metric(object): return self._vars # ---- To be implemented by descendants --- + def build(self, *args, **kwargs): + """Method to create variables. + + Called by `__call__()` before `call()` for the first time. + + Args: + *args: + **kwargs: The arguments to the first invocation of `__call__()`. + `build()` may use the shape and/or dtype of these arguments + when deciding how to create variables. + """ + raise NotImplementedError("Metrics must define a build() member function") + def call(self, *args, **kwargs): - """Accumulates statistics for the metric.""" + """Accumulates statistics for the metric. Users should use __call__ instead. + + Note: This function is executed as a graph function in graph mode. + This means: + a) Operations on the same resource are executed in textual order. + This should make it easier to do things like add the updated + value of a variable to another, for example. + b) You don't need to worry about collecting the update ops to execute. + All update ops added to the graph by this function will be executed. + As a result, code should generally work the same way with graph or + eager execution. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric, as passed to + `__call__()`. + """ raise NotImplementedError("Metrics must define a call() member function") + def result(self): # TODO(josh11b): Add an optional summary_writer parameter. + """Computes and returns a final value for the metric.""" + raise NotImplementedError("Metrics must define a result() member function") + # We can support two different strategies of for doing data-parallel # distributed metric computations: # * Put metric variables on the first device and rely on small @@ -123,16 +174,19 @@ class Metric(object): self._vars[i].assign_add(math_ops.add_n([m._vars[i] for m in metrics])) # pylint: enable=protected-access - def result(self): # TODO(josh11b): Add an optional summary_writer parameter. - """Computes and returns a final value for the metric.""" - raise NotImplementedError("Metrics must define a result() member function") + def reset(self): + """Reset this metric to a freshly initialized state. + + Default implementation zeros all the metric variables. + """ + for v in self._vars: + v.assign(math_ops.zeros_like(v)) # ---- For use by descendants --- def add_variable(self, name, shape=None, dtype=None, initializer=None): """***Only for use by descendants of Metric***.""" - if self.built: - raise RuntimeError("Can't call add_variable() after a Metric has been " - "built in the first call().") + if self._built: + raise RuntimeError("Can't call add_variable() except in build().") v = variable_scope.get_variable(name, shape, dtype, initializer, trainable=False, use_resource=True) self._vars.append(v) @@ -144,6 +198,15 @@ class Mean(Metric): # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64? # Or defaults to type of the input if it is tf.float32, else tf.float64? + def build(self, values, weights=None): + del values, weights # build() does not use call's arguments + self.numer = self.add_variable(name="numer", shape=(), + dtype=dtypes.float64, + initializer=init_ops.zeros_initializer) + self.denom = self.add_variable(name="denom", shape=(), + dtype=dtypes.float64, + initializer=init_ops.zeros_initializer) + def call(self, values, weights=None): """Accumulate statistics for computing the mean. @@ -154,13 +217,6 @@ class Mean(Metric): values: Tensor with the per-example value. weights: Optional weighting of each example. Defaults to 1. """ - if not self.built: # False only in the first call(). - self.numer = self.add_variable(name="numer", shape=(), - dtype=dtypes.float64, - initializer=init_ops.zeros_initializer) - self.denom = self.add_variable(name="denom", shape=(), - dtype=dtypes.float64, - initializer=init_ops.zeros_initializer) if weights is None: self.denom.assign_add( math_ops.cast(array_ops.size(values), dtypes.float64)) @@ -179,6 +235,10 @@ class Mean(Metric): class Accuracy(Mean): """Calculates how often `predictions` matches `labels`.""" + def build(self, labels, predictions, weights=None): + del labels, predictions, weights + super(Accuracy, self).build(None) # Arguments are unused + def call(self, labels, predictions, weights=None): """Accumulate accuracy statistics. diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 089bad5a0e3..bfb79cd72e0 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -19,7 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.eager.python import metrics +from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables class MetricsTest(test.TestCase): @@ -56,6 +60,53 @@ class MetricsTest(test.TestCase): m([7], [2]) # 0 correct, weight 1 self.assertEqual(2.5/5, m.result().numpy()) + def testTwoMeans(self): + # Verify two metrics with the same class and name don't + # accidentally share state. + m1 = metrics.Mean() + m2 = metrics.Mean() + m1(0) + m2(2) + self.assertEqual(0, m1.result().numpy()) + self.assertEqual(2, m2.result().numpy()) + self.assertNotEqual(m1.name, m2.name) + + def testNamesWithSpaces(self): + # Verify two metrics with the same class and name don't + # accidentally share state. + m1 = metrics.Mean("has space") + m2 = metrics.Mean("has space") + m2(2) + m1(0) + self.assertEqual(m1.name, "has space") + self.assertEqual(m1.numer.name, "has_space/numer:0") + self.assertEqual(m2.name, "has space_1") + self.assertEqual(m2.numer.name, "has_space_1/numer:0") + + def testGraph(self): + with context.graph_mode(), self.test_session() as sess: + m = metrics.Mean() + p = array_ops.placeholder(dtypes.float32) + accumulate = m(p) + variables.global_variables_initializer().run() + sess.run(accumulate, feed_dict={p: [1, 10, 100]}) + sess.run(accumulate, feed_dict={p: 1000}) + sess.run(accumulate, feed_dict={p: [10000, 100000]}) + self.assertAllEqual(m.result().eval(), 111111.0/6) + + def testTwoMeansGraph(self): + # Verify two metrics with the same class and name don't + # accidentally share state. + with context.graph_mode(), self.test_session() as sess: + m1 = metrics.Mean() + m2 = metrics.Mean() + accumulate1 = m1(0) + accumulate2 = m2(2) + variables.global_variables_initializer().run() + sess.run([accumulate1, accumulate2]) + self.assertEqual(0, m1.result().eval()) + self.assertEqual(2, m2.result().eval()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index da49517cf94..e675ee8988f 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -79,6 +79,22 @@ def capture_tensors(captures): _scoped_captures.tensors = old +def capture_value(tensor_map, value, dtype, name): + """Capture a value from outside the function, to pass in as an extra arg.""" + captured_value = tensor_map.get(ops.tensor_id(value), None) + if captured_value is None: + captured_value = graph_placeholder( + dtype=dtype or value.dtype, shape=value.shape, name=name) + if captured_value.dtype == dtypes.resource: + captured_value._handle_data = value._handle_data # pylint: disable=protected-access + tensor_map[ops.tensor_id(value)] = (value, captured_value) + else: + captured_value = captured_value[1] + tape.record_operation("captured_value", [captured_value], [value], + lambda x: [x]) + return captured_value + + def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False): """Captures a Tensor while building a graph mode function. @@ -100,18 +116,33 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False): if tensor_map is None: # Capturing is not enabled. return constant_op.constant(value.numpy()) - captured_value = tensor_map.get(ops.tensor_id(value), None) - if captured_value is None: - captured_value = graph_placeholder( - dtype=dtype or value.dtype, shape=value.shape, name=name) - if captured_value.dtype == dtypes.resource: - captured_value._handle_data = value._handle_data # pylint: disable=protected-access - tensor_map[ops.tensor_id(value)] = (value, captured_value) - else: - captured_value = captured_value[1] - tape.record_operation("captured_value", [captured_value], [value], - lambda x: [x]) - return captured_value + return capture_value(tensor_map, value, dtype, name) + + +class CapturingGraph(ops.Graph): + + def __init__(self, captures): + super(CapturingGraph, self).__init__() + self._building_function = True + self.captures = captures + + def create_op( + self, + op_type, + inputs, + dtypes, # pylint: disable=redefined-outer-name + input_types=None, + name=None, + attrs=None, + op_def=None, + compute_shapes=True, + compute_device=True): + for i, inp in enumerate(inputs): + if inp.graph is not self: + inputs[i] = capture_value(self.captures, inp, inp.dtype, inp.op.name) + return super(CapturingGraph, self).create_op( + op_type, inputs, dtypes, input_types, name, attrs, op_def, + compute_shapes, compute_device) # TODO(apassos): it'd be really nice if we could scope this registration. @@ -325,6 +356,8 @@ class _GraphModeFunction(object): name="FunctionCall", compute_shapes=False) result = op.outputs + if not result: + return op for i, s in enumerate(self._output_shapes): result[i].set_shape(s) else: @@ -381,7 +414,8 @@ def _get_defun_inputs(args): def _defun_internal(name, func, args, kwds): """Defines and returns graph-mode version of func.""" with context.graph_mode(): - tmp_graph = ops.Graph() + captures = {} + tmp_graph = CapturingGraph(captures) # Copy the graph collections to ensure summaries and other things work. This # lets the function access (but not mutate) collections of the containing # graph, such as the global step and the summary writer collections. @@ -392,7 +426,6 @@ def _defun_internal(name, func, args, kwds): with tmp_graph.as_default(): func_inputs = _get_defun_inputs(args) - captures = {} with capture_tensors(captures): func_outputs = func(*func_inputs, **kwds) ids = list(sorted(captures.keys())) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index fb647f5c211..a4c351e8c9f 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables class FunctionTest(test.TestCase): @@ -68,6 +69,23 @@ class FunctionTest(test.TestCase): self.assertAllEqual(step(), 2.0) + def testGraphModeCaptureVariable(self): + with context.graph_mode(), self.test_session() as sess: + + class HasAVar(object): + + def __init__(self): + self.v = resource_variable_ops.ResourceVariable(1.0) + + def call(self): + return self.v * 2 + + o = HasAVar() + variables.global_variables_initializer().run() + call = function.defun(o.call) + op = call() + self.assertAllEqual(sess.run(op), 2.0) + def testTensorConversionWithDefun(self): @function.defun diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 4e72d025a22..1703cae1e5d 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -95,6 +95,7 @@ do_pylint() { "^tensorflow/python/platform/default/_googletest\.py.*\[E0102.*function\salready\sdefined "\ "^tensorflow/python/feature_column/feature_column_test\.py.*\[E0110.*abstract-class-instantiated "\ "^tensorflow/contrib/layers/python/layers/feature_column\.py.*\[E0110.*abstract-class-instantiated "\ +"^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\ "^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\ "^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable" From 3715cffc6e2338cf2fc6ad6aba5c1d00ce598bfd Mon Sep 17 00:00:00 2001 From: Anna R Date: Thu, 19 Oct 2017 18:27:01 -0700 Subject: [PATCH 18/41] Internal change. PiperOrigin-RevId: 172829126 --- tensorflow/core/framework/api_def.proto | 5 +- tensorflow/core/framework/op_gen_lib.cc | 18 +++++++ tensorflow/core/framework/op_gen_lib_test.cc | 50 ++++++++++++++++++-- 3 files changed, 68 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/framework/api_def.proto b/tensorflow/core/framework/api_def.proto index 987caee2506..98c38efc0e9 100644 --- a/tensorflow/core/framework/api_def.proto +++ b/tensorflow/core/framework/api_def.proto @@ -51,7 +51,8 @@ message ApiDef { // endpoints are deprecated). message Endpoint { // Name should be either like "CamelCaseName" or - // "Package.CamelCaseName". + // "Package.CamelCaseName". Client-language-specific ApiDefs may + // use a snake_case convention instead of CamelCase. string name = 1; // First GraphDef version at which the op is disallowed. @@ -74,7 +75,7 @@ message ApiDef { } repeated Arg in_arg = 4; repeated Arg out_arg = 5; - // List of post-rename in_arg names to specify new argument order. + // List of original in_arg names to specify new argument order. // Length of arg_order should be either empty to keep current order // or match size of in_arg. repeated string arg_order = 11; diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc index cfaca897ba8..1e93e9be095 100644 --- a/tensorflow/core/framework/op_gen_lib.cc +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -412,6 +412,8 @@ void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) { api_in_arg->set_name(op_in_arg.name()); api_in_arg->set_rename_to(op_in_arg.name()); api_in_arg->set_description(op_in_arg.description()); + + *api_def->add_arg_order() = op_in_arg.name(); } for (const auto& op_out_arg : op_def.output_arg()) { auto* api_out_arg = api_def->add_out_arg(); @@ -503,6 +505,22 @@ Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) { } // Merge arg order if (new_api_def.arg_order_size() > 0) { + // Validate that new arg_order is correct. + if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) { + return errors::FailedPrecondition( + "Invalid number of arguments ", new_api_def.arg_order_size(), " for ", + base_api_def->graph_op_name(), + ". Expected: ", base_api_def->arg_order_size()); + } + if (!std::is_permutation(new_api_def.arg_order().begin(), + new_api_def.arg_order().end(), + base_api_def->arg_order().begin())) { + return errors::FailedPrecondition( + "Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "), + " for ", base_api_def->graph_op_name(), + ". All elements in arg_order override must match base arg_order: ", + str_util::Join(base_api_def->arg_order(), ", ")); + } base_api_def->clear_arg_order(); std::copy( new_api_def.arg_order().begin(), new_api_def.arg_order().end(), diff --git a/tensorflow/core/framework/op_gen_lib_test.cc b/tensorflow/core/framework/op_gen_lib_test.cc index b7ee6db9912..da9b4dfbb17 100644 --- a/tensorflow/core/framework/op_gen_lib_test.cc +++ b/tensorflow/core/framework/op_gen_lib_test.cc @@ -207,6 +207,8 @@ attr { name: "attr_a" rename_to: "attr_a" } +arg_order: "arg_a" +arg_order: "arg_b" )"; OpList op_list; protobuf::TextFormat::ParseFromString(kTestOpList, &op_list); // NOLINT @@ -331,8 +333,8 @@ op { name: "arg_c" rename_to: "arg_cc" } - arg_order: "arg_aa" arg_order: "arg_b" + arg_order: "arg_a" } )"; OpList op_list; @@ -351,8 +353,8 @@ op { EXPECT_EQ("arg_cc", api_def->out_arg(0).rename_to()); ASSERT_EQ(2, api_def->arg_order_size()); - EXPECT_EQ("arg_aa", api_def->arg_order(0)); - EXPECT_EQ("arg_b", api_def->arg_order(1)); + EXPECT_EQ("arg_b", api_def->arg_order(0)); + EXPECT_EQ("arg_a", api_def->arg_order(1)); } TEST(OpGenLibTest, ApiDefOverrideDescriptions) { @@ -411,5 +413,47 @@ op { auto status = api_map.LoadApiDef(api_def1); ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); } + +TEST(OpGenLibTest, ApiDefInvalidArgOrder) { + const string api_def1 = R"( +op { + graph_op_name: "testop" + arg_order: "arg_a" + arg_order: "unexpected_arg" +} +)"; + + const string api_def2 = R"( +op { + graph_op_name: "testop" + arg_order: "arg_a" +} +)"; + + const string api_def3 = R"( +op { + graph_op_name: "testop" + arg_order: "arg_a" + arg_order: "arg_a" +} +)"; + + OpList op_list; + protobuf::TextFormat::ParseFromString(kTestOpList, &op_list); // NOLINT + ApiDefMap api_map(op_list); + TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef)); + + // Loading with incorrect arg name in arg_order should fail. + auto status = api_map.LoadApiDef(api_def1); + ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); + + // Loading with incorrect number of args in arg_order should fail. + status = api_map.LoadApiDef(api_def2); + ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); + + // Loading with the same argument twice in arg_order should fail. + status = api_map.LoadApiDef(api_def3); + ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); +} } // namespace } // namespace tensorflow From 0671c0b2546dbea87e231d336d5f4c0573a01964 Mon Sep 17 00:00:00 2001 From: David Soergel Date: Thu, 19 Oct 2017 19:01:19 -0700 Subject: [PATCH 19/41] Usability improvements regarding export signature generation. * Log report of which signatures are produced and which TF Serving APIs are targeted. * Improve docstrings for signature_def builders, explaining the TF Serving API constraints. * Accept a single Tensor as a prediction output (which will be named 'output'). PiperOrigin-RevId: 172831366 --- tensorflow/python/estimator/export/export.py | 56 +++++++++++++++++-- .../python/estimator/export/export_output.py | 10 ++-- .../estimator/export/export_output_test.py | 14 ++--- .../saved_model/signature_def_utils_impl.py | 23 ++++++-- 4 files changed, 80 insertions(+), 23 deletions(-) diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index e2e20f0d717..31e9933c6f7 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.util import compat @@ -47,8 +48,8 @@ class ServingInputReceiver(collections.namedtuple( """A return type for a serving_input_receiver_fn. The expected return values are: - features: A dict of string to `Tensor` or `SparseTensor`, specifying the - features to be passed to the model. + features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the features to be passed to the model. receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying input nodes where this receiver expects to be fed by default. Typically, this is a single placeholder expecting serialized `tf.Example` protos. @@ -193,13 +194,14 @@ def build_all_signature_defs(receiver_tensors, raise ValueError('export_outputs must be a dict.') signature_def_map = {} + excluded_signatures = {} for output_key, export_output in export_outputs.items(): signature_name = '{}'.format(output_key or 'None') try: signature = export_output.as_signature_def(receiver_tensors) signature_def_map[signature_name] = signature - except ValueError: - pass + except ValueError as e: + excluded_signatures[signature_name] = str(e) if receiver_tensors_alternatives: for receiver_name, receiver_tensors_alt in ( @@ -213,8 +215,10 @@ def build_all_signature_defs(receiver_tensors, try: signature = export_output.as_signature_def(receiver_tensors_alt) signature_def_map[signature_name] = signature - except ValueError: - pass + except ValueError as e: + excluded_signatures[signature_name] = str(e) + + _log_signature_report(signature_def_map, excluded_signatures) # The above calls to export_output.as_signature_def should return only # valid signatures; if there is a validity problem, they raise ValueError, @@ -224,6 +228,46 @@ def build_all_signature_defs(receiver_tensors, if signature_def_utils.is_valid_signature(v)} +_FRIENDLY_METHOD_NAMES = { + signature_constants.CLASSIFY_METHOD_NAME: 'Classify', + signature_constants.REGRESS_METHOD_NAME: 'Regress', + signature_constants.PREDICT_METHOD_NAME: 'Predict', +} + + +def _log_signature_report(signature_def_map, excluded_signatures): + """Log a report of which signatures were produced.""" + sig_names_by_method_name = collections.defaultdict(list) + + # We'll collect whatever method_names are present, but also we want to make + # sure to output a line for each of the three standard methods even if they + # have no signatures. + for method_name in _FRIENDLY_METHOD_NAMES: + sig_names_by_method_name[method_name] = [] + + for signature_name, sig in signature_def_map.items(): + sig_names_by_method_name[sig.method_name].append(signature_name) + + # TODO(b/67733540): consider printing the full signatures, not just names + for method_name, sig_names in sig_names_by_method_name.items(): + if method_name in _FRIENDLY_METHOD_NAMES: + method_name = _FRIENDLY_METHOD_NAMES[method_name] + logging.info('Signatures INCLUDED in export for {}: {}'.format( + method_name, sig_names if sig_names else 'None')) + + if excluded_signatures: + logging.info('Signatures EXCLUDED from export because they cannot be ' + 'be served via TensorFlow Serving APIs:') + for signature_name, message in excluded_signatures.items(): + logging.info('\'{}\' : {}'.format(signature_name, message)) + + if not signature_def_map: + logging.warn('Export includes no signatures!') + elif (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + not in signature_def_map): + logging.warn('Export includes no default signature!') + + # When we create a timestamped directory, there is a small chance that the # directory already exists because another worker is also writing exports. # In this case we just wait one second to get a new timestamp and try again. diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 7c7f92872eb..863af6d41d9 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -150,6 +150,9 @@ class RegressionOutput(ExportOutput): return signature_def_utils.regression_signature_def(examples, self.value) +_SINGLE_OUTPUT_DEFAULT_NAME = 'output' + + class PredictOutput(ExportOutput): """Represents the output of a generic prediction head. @@ -162,16 +165,15 @@ class PredictOutput(ExportOutput): """Constructor for PredictOutput. Args: - outputs: A dict of string to `Tensor` representing the predictions. + outputs: A `Tensor` or a dict of string to `Tensor` representing the + predictions. Raises: ValueError: if the outputs is not dict, or any of its keys are not strings, or any of its values are not `Tensor`s. """ if not isinstance(outputs, dict): - raise ValueError( - 'Prediction outputs must be given as a dict of string to Tensor; ' - 'got {}'.format(outputs)) + outputs = {_SINGLE_OUTPUT_DEFAULT_NAME: outputs} for key, value in outputs.items(): if not isinstance(key, six.string_types): raise ValueError( diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py index 035a9a143e6..7090e53d807 100644 --- a/tensorflow/python/estimator/export/export_output_test.py +++ b/tensorflow/python/estimator/export/export_output_test.py @@ -199,20 +199,18 @@ class ExportOutputTest(test.TestCase): signature_constants.CLASSIFY_METHOD_NAME) self.assertEqual(actual_signature_def, expected_signature_def) - def test_predict_output_constructor(self): - """Tests that no errors are raised when input is expected.""" + def test_predict_outputs_valid(self): + """Tests that no errors are raised when provided outputs are valid.""" outputs = { "output0": constant_op.constant([0]), - u"output1": constant_op.constant([1]), + u"output1": constant_op.constant(["foo"]), } export_output_lib.PredictOutput(outputs) - def test_predict_output_outputs_invalid(self): - with self.assertRaisesRegexp( - ValueError, - "Prediction outputs must be given as a dict of string to Tensor"): - export_output_lib.PredictOutput(constant_op.constant([0])) + # Single Tensor is OK too + export_output_lib.PredictOutput(constant_op.constant([0])) + def test_predict_outputs_invalid(self): with self.assertRaisesRegexp( ValueError, "Prediction output key must be a string"): diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py index 564befeb0b5..240ea61aa5f 100644 --- a/tensorflow/python/saved_model/signature_def_utils_impl.py +++ b/tensorflow/python/saved_model/signature_def_utils_impl.py @@ -56,9 +56,13 @@ def build_signature_def(inputs=None, outputs=None, method_name=None): def regression_signature_def(examples, predictions): """Creates regression signature from given examples and predictions. + This function produces signatures intended for use with the TensorFlow Serving + Regress API (tensorflow_serving/apis/prediction_service.proto), and so + constrains the input and output types to those allowed by TensorFlow Serving. + Args: - examples: `Tensor`. - predictions: `Tensor`. + examples: A string `Tensor`, expected to accept serialized tf.Examples. + predictions: A float `Tensor`. Returns: A regression-flavored signature_def. @@ -93,10 +97,15 @@ def regression_signature_def(examples, predictions): def classification_signature_def(examples, classes, scores): """Creates classification signature from given examples and predictions. + This function produces signatures intended for use with the TensorFlow Serving + Classify API (tensorflow_serving/apis/prediction_service.proto), and so + constrains the input and output types to those allowed by TensorFlow Serving. + Args: - examples: `Tensor`. - classes: `Tensor`. - scores: `Tensor`. + examples: A string `Tensor`, expected to accept serialized tf.Examples. + classes: A string `Tensor`. Note that the ClassificationResponse message + requires that class labels are strings, not integers or anything else. + scores: a float `Tensor`. Returns: A classification-flavored signature_def. @@ -140,6 +149,10 @@ def classification_signature_def(examples, classes, scores): def predict_signature_def(inputs, outputs): """Creates prediction signature from given inputs and outputs. + This function produces signatures intended for use with the TensorFlow Serving + Predict API (tensorflow_serving/apis/prediction_service.proto). This API + imposes no constraints on the input and output types. + Args: inputs: dict of string to `Tensor`. outputs: dict of string to `Tensor`. From 58121b8b13597d3285f121f02bd2a512bc76be17 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Oct 2017 20:09:38 -0700 Subject: [PATCH 20/41] Pull out a non-test-only class HloRunnerBase from HloTestBase so that it can be used as a library for running HloModule on given platform. Also add a function to read HloModule from a HloProto file, and a function to make fake input literals for given HloModule. PiperOrigin-RevId: 172835863 --- tensorflow/compiler/xla/service/BUILD | 23 ++ tensorflow/compiler/xla/service/hlo_runner.cc | 199 ++++++++++++++++++ tensorflow/compiler/xla/service/hlo_runner.h | 100 +++++++++ tensorflow/compiler/xla/tests/BUILD | 12 +- .../compiler/xla/tests/hlo_test_base.cc | 114 +--------- tensorflow/compiler/xla/tests/hlo_test_base.h | 26 +-- 6 files changed, 335 insertions(+), 139 deletions(-) create mode 100644 tensorflow/compiler/xla/service/hlo_runner.cc create mode 100644 tensorflow/compiler/xla/service/hlo_runner.h diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 1ef329365ea..8f5105aa530 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2066,6 +2066,29 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_runner", + srcs = ["hlo_runner.cc"], + hdrs = ["hlo_runner.h"], + deps = [ + ":executable", + ":hlo", + ":transfer_manager", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//third_party/eigen3", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc new file mode 100644 index 00000000000..d5d7042a02b --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -0,0 +1,199 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_runner.h" + +#include +#include +#include + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +/*static*/ StatusOr> +HloRunner::ReadModuleFromHloProtoFile(const char* filename, + const DebugOptions& debug_options) { + HloProto proto; + TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), + filename, &proto)); + HloModuleConfig config; + config.set_debug_options(debug_options); + TF_ASSIGN_OR_RETURN(auto module, HloModule::CreateFromProto( + proto.hlo_module(), + VersionedComputationHandle(), config)); + return std::move(module); +} + +// Define this in .cc file to avoid having to include eigen or forward declare +// these types in the header. +struct HloRunner::EigenThreadPoolWrapper { + std::unique_ptr pool; + std::unique_ptr device; +}; + +HloRunner::HloRunner() {} + +HloRunner::HloRunner(se::Platform* platform) { + BackendOptions backend_options; + backend_options.set_platform(platform); + backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie(); + VLOG(1) << "Created HloRunner for platform: " << platform->Name(); +} + +HloRunner::~HloRunner() { + // Deallocate all the memory allocated during the tests. + for (auto& allocation : allocations_) { + backend().default_stream_executor()->Deallocate(&allocation); + } +} + +StatusOr HloRunner::Execute( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments, + Shape* result_shape) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + backend().compiler()->Compile(std::move(module), + backend().default_stream_executor())); + + se::Stream stream(backend().default_stream_executor()); + stream.Init(); + + ExecutableRunOptions run_options; + run_options.set_stream(&stream); + run_options.set_allocator(backend().memory_allocator()); + run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); + run_options.set_intra_op_thread_pool( + backend().eigen_intra_op_thread_pool_device()); + + HloExecutionProfile hlo_execution_profile; + ServiceExecutableRunOptions service_run_options( + run_options, backend().StreamBorrower(), + backend().inter_op_thread_pool()); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase result, + executable->ExecuteOnStream(&service_run_options, arguments, + &hlo_execution_profile)); + TF_RET_CHECK(stream.BlockHostUntilDone()); + + allocations_.push_back(result); + + *result_shape = executable->result_shape(); + + if (ShapeUtil::IsTuple(*result_shape)) { + // We must record element buffers of tuples as well to avoid leaks. + DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); + TF_ASSIGN_OR_RETURN( + std::vector element_buffers, + backend().transfer_manager()->ShallowCopyTupleFromDevice( + backend().default_stream_executor(), result, *result_shape)); + + // A tuple may contain the same buffer in more than one element. Keep track + // of the buffers already added to avoid duplicates in allocations_. + std::set added_opaques; + for (auto element_buffer : element_buffers) { + if (added_opaques.count(element_buffer.opaque()) == 0) { + CHECK(element_buffer.opaque() != nullptr); + added_opaques.insert(element_buffer.opaque()); + allocations_.push_back(element_buffer); + } + } + } + + return result; +} + +se::DeviceMemoryBase HloRunner::TransferToDevice(const Literal& literal) { + // Allocate memory on the device using the stream executor. + int64 allocation_size = + backend().transfer_manager()->GetByteSizeRequirement(literal.shape()); + se::DeviceMemoryBase allocation = + backend().default_stream_executor()->AllocateArray( + allocation_size); + allocations_.push_back(allocation); + + TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice( + backend().default_stream_executor(), literal, &allocation)); + + return allocation; +} + +std::unique_ptr HloRunner::TransferFromDevice( + const Shape& shape, se::DeviceMemoryBase device_base) { + auto literal = MakeUnique(); + TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice( + backend().default_stream_executor(), device_base, shape, shape, + literal.get())); + return literal; +} + +std::unique_ptr HloRunner::ExecuteAndTransfer( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments) { + Shape result_shape; + se::DeviceMemoryBase device_base = + Execute(std::move(module), arguments, &result_shape).ValueOrDie(); + return TransferFromDevice(result_shape, device_base); +} + +template <> +std::unique_ptr HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice>& literals) { + std::vector arguments; + for (const auto& literal : literals) { + arguments.push_back(TransferToDevice(*literal)); + } + return ExecuteAndTransfer(std::move(module), arguments); +} + +template <> +std::unique_ptr HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice& literals) { + std::vector arguments; + for (const auto& literal : literals) { + arguments.push_back(TransferToDevice(*literal)); + } + return ExecuteAndTransfer(std::move(module), arguments); +} + +Backend& HloRunner::backend() { + if (!backend_) { + backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie(); + VLOG(1) << "executing on platform " << backend().platform()->Name(); + } + return *backend_; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h new file mode 100644 index 00000000000..d74a1b59a8c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -0,0 +1,100 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// A base class for running an HloModule. This executes the given HloModule on a +// certain backend directly without using the client interface. HloModule can be +// explicitly built, or loaded from a serialization file (e.g., hlo proto file). +class HloRunner { + public: + HloRunner(); + + HloRunner(::perftools::gputools::Platform* platform); + + ~HloRunner(); + + // Reads the binary proto file in xla.HloProto format, creates and returns the + // HloModule. + static StatusOr> ReadModuleFromHloProtoFile( + const char* filename, const DebugOptions& debug_options); + + // Executes the given module with given literals as input and returns the + // result as a Literal. The LiteralPtr type accepts Literal* or + // std::unique_ptr. + template + std::unique_ptr Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice& literals); + + // Executes the given module and returns a global data handle. + StatusOr Execute( + std::unique_ptr module, + tensorflow::gtl::ArraySlice + arguments, + Shape* result_shape); + + // Transfers the given literal to the device and returns the data handle. + perftools::gputools::DeviceMemoryBase TransferToDevice( + const Literal& literal); + + // Transfers the array referred to by the given handle from the device and + // returns as a Literal. + std::unique_ptr TransferFromDevice( + const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); + + // Executes the given module and return the result as a Literal. + std::unique_ptr ExecuteAndTransfer( + std::unique_ptr module, + tensorflow::gtl::ArraySlice + arguments); + + // If backend is not created in the constructor, creates and returns the + // default backend. If creation fails, crashes the program. + // + // This creates the backend lazily so it's possible to instantiate an + // HloRunner in a program without any backends linked in. + Backend& backend(); + + private: + struct EigenThreadPoolWrapper; + + std::vector allocations_; + + std::unique_ptr thread_pool_wrapper_; + + std::unique_ptr backend_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index b02d906d93e..43127925e65 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -102,28 +102,18 @@ cc_library( deps = [ ":literal_test_util", "//tensorflow/compiler/xla:shape_layout", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:backend", - "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_layout", - "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_execution_profile", - "//tensorflow/compiler/xla/service:hlo_graph_dumper", - "//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/core:core_cpu_internal", + "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", - "//third_party/eigen3", ], ) diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 26513d6ce8e..3e244fbfd9d 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -19,24 +19,9 @@ limitations under the License. #include #include -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/transfer_manager.h" -#include "tensorflow/compiler/xla/shape_layout.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -45,22 +30,6 @@ namespace se = ::perftools::gputools; namespace xla { -// Define this in .cc file to avoid having to include eigen or forward declare -// these types in the header. -struct HloTestBase::EigenThreadPoolWrapper { - std::unique_ptr pool; - std::unique_ptr device; -}; - -HloTestBase::HloTestBase() {} - -HloTestBase::~HloTestBase() { - // Deallocate all the memory allocated during the tests. - for (auto& allocation : allocations_) { - backend().default_stream_executor()->Deallocate(&allocation); - } -} - /* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; @@ -80,98 +49,25 @@ StatusOr HloTestBase::Execute( tensorflow::gtl::ArraySlice arguments, Shape* result_shape) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - backend().compiler()->Compile(std::move(module), - backend().default_stream_executor())); - - se::Stream stream(backend().default_stream_executor()); - stream.Init(); - - ExecutableRunOptions run_options; - run_options.set_stream(&stream); - run_options.set_allocator(backend().memory_allocator()); - run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); - run_options.set_intra_op_thread_pool( - backend().eigen_intra_op_thread_pool_device()); - - HloExecutionProfile hlo_execution_profile; - ServiceExecutableRunOptions service_run_options( - run_options, backend().StreamBorrower(), - backend().inter_op_thread_pool()); - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase result, - executable->ExecuteOnStream(&service_run_options, arguments, - &hlo_execution_profile)); - TF_RET_CHECK(stream.BlockHostUntilDone()); - - allocations_.push_back(result); - - *result_shape = executable->result_shape(); - - if (ShapeUtil::IsTuple(*result_shape)) { - // We must record element buffers of tuples as well to avoid leaks. - DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); - TF_ASSIGN_OR_RETURN( - std::vector element_buffers, - backend().transfer_manager()->ShallowCopyTupleFromDevice( - backend().default_stream_executor(), result, *result_shape)); - - // A tuple may contain the same buffer in more than one element. Keep track - // of the buffers already added to avoid duplicates in allocations_. - std::set added_opaques; - for (auto element_buffer : element_buffers) { - if (added_opaques.count(element_buffer.opaque()) == 0) { - CHECK(element_buffer.opaque() != nullptr); - added_opaques.insert(element_buffer.opaque()); - allocations_.push_back(element_buffer); - } - } - } - - return result; + return runner_.Execute(std::move(module), arguments, result_shape); } se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) { - // Allocate memory on the device using the stream executor. - int64 allocation_size = - backend().transfer_manager()->GetByteSizeRequirement(literal.shape()); - se::DeviceMemoryBase allocation = - backend().default_stream_executor()->AllocateArray( - allocation_size); - allocations_.push_back(allocation); - - TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice( - backend().default_stream_executor(), literal, &allocation)); - - return allocation; + return runner_.TransferToDevice(literal); } std::unique_ptr HloTestBase::TransferFromDevice( const Shape& shape, se::DeviceMemoryBase device_base) { - auto literal = MakeUnique(); - TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice( - backend().default_stream_executor(), device_base, shape, shape, - literal.get())); - return literal; + return runner_.TransferFromDevice(shape, device_base); } std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments) { - Shape result_shape; - se::DeviceMemoryBase device_base = - Execute(std::move(module), arguments, &result_shape).ValueOrDie(); - return TransferFromDevice(result_shape, device_base); + return runner_.ExecuteAndTransfer(std::move(module), arguments); } -Backend& HloTestBase::backend() { - if (!backend_) { - backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie(); - VLOG(1) << "executing on platform " << backend().platform()->Name(); - } - return *backend_; -} +Backend& HloTestBase::backend() { return runner_.backend(); } /* static */ string HloTestBase::TestName() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 275f1f5c7ba..7f068dce36b 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -21,12 +21,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/compiler.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -39,10 +39,9 @@ namespace xla { // building a graph of HLO instructions to run. class HloTestBase : public ::testing::Test { protected: - struct EigenThreadPoolWrapper; - HloTestBase(); + HloTestBase() {} - ~HloTestBase() override; + ~HloTestBase() override {} // Creates a new HLO module for a test. The module created will have // TestName() for its name; it will also automatically populate its debug @@ -102,23 +101,12 @@ class HloTestBase : public ::testing::Test { static string TestName(); - // Creates (if necessary) and returns the default backend. If creation fails, - // crashes the program. - // - // This creates the backend lazily so it's possible to instantiate an - // HloTestBase in a program without any backends linked in. + // Returns the backend owned by the HloRunner. Backend& backend(); - // This vector contains handles of all the device memory allocations performed - // by the test. These are deallocated on destruction of the test object. - std::vector allocations_; + HloRunner runner_; ErrorSpec error_spec_{0.0001}; - - std::unique_ptr thread_pool_wrapper_; - - private: - std::unique_ptr backend_; // Lazily populated. Access via backend(). }; } // namespace xla From aa9ddb2006cba090a53ea978a6ec78bea8245805 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Oct 2017 20:10:03 -0700 Subject: [PATCH 21/41] Add a tool which reads the Hlo module proto and convert it into JSON format. PiperOrigin-RevId: 172835881 --- tensorflow/compiler/xla/tools/BUILD | 12 +++ .../compiler/xla/tools/hlo_proto_to_json.cc | 91 +++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 tensorflow/compiler/xla/tools/hlo_proto_to_json.cc diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 0451537af77..759921dce5a 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -210,6 +210,18 @@ tf_cc_binary( ], ) +tf_cc_binary( + name = "hlo_proto_to_json", + srcs = ["hlo_proto_to_json.cc"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc new file mode 100644 index 00000000000..4e02e17db65 --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Usage: +// hlo_proto_to_json --input_file=some_binary_proto +// --output_file=path_to_dump_output +// +// Reads one serilized Hlo module, convert it into JSON format and dump into +// some output directory. some_binaray_proto is obtained by serializing Hlo +// module to disk using --xla_dump_hlo_proto_to debug optoin. + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +using tensorflow::Env; +using xla::string; + +namespace xla { +namespace tools { + +StatusOr ToJson(const tensorflow::protobuf::Message& message) { + string json_output; + tensorflow::protobuf::util::JsonPrintOptions json_options; + json_options.add_whitespace = true; + json_options.always_print_primitive_fields = true; + auto status = tensorflow::protobuf::util::MessageToJsonString( + message, &json_output, json_options); + if (!status.ok()) { + return InternalError("MessageToJsonString failed: %s", + status.error_message().data()); + } + return json_output; +} + +void RealMain(const string& input, const string& output) { + HloProto hlo_proto; + TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), input, + &hlo_proto)) + << "Can't open, read, or parse input file " << input; + + auto statusor = ToJson(hlo_proto); + QCHECK(statusor.ok()) << "Error converting " << input << " to JSON." + << statusor.status(); + + TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output, + statusor.ValueOrDie())); +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + string input_file, output_file; + const std::vector flag_list = { + tensorflow::Flag("input_file", &input_file, "file to convert."), + tensorflow::Flag("output_file", &output_file, "converted file"), + }; + const string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(parse_ok && argc == 1) << "\n" << usage; + + QCHECK(!input_file.empty()) << "--input_file is required"; + QCHECK(!output_file.empty()) << "--output_file is required"; + + xla::tools::RealMain(input_file, output_file); + + return 0; +} From 492ddb55a9b31a07026b7d82a2f9bcac29f4ee65 Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Thu, 19 Oct 2017 21:09:44 -0700 Subject: [PATCH 22/41] Add support for fused batch norm to fake quantize rewriter. PiperOrigin-RevId: 172839124 --- tensorflow/contrib/quantize/BUILD | 33 +- .../quantize/python/copy_graph_test.py | 2 +- .../quantize/python/fold_batch_norms.py | 269 ++++++++++++- .../quantize/python/fold_batch_norms_test.py | 368 ++++++------------ .../contrib/quantize/python/graph_matcher.py | 200 ++++++++++ .../quantize/python/graph_matcher_test.py | 130 +++++++ .../python/quantize_parameterized_test.py | 212 +++++----- 7 files changed, 853 insertions(+), 361 deletions(-) create mode 100644 tensorflow/contrib/quantize/python/graph_matcher.py create mode 100644 tensorflow/contrib/quantize/python/graph_matcher_test.py diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index 7ff186bc2ad..0d6c71965cb 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -13,6 +13,34 @@ py_library( deps = [], ) +py_library( + name = "graph_matcher", + srcs = [ + "python/graph_matcher.py", + ], + srcs_version = "PY2AND3", + deps = [], +) + +py_test( + name = "graph_matcher_test", + size = "small", + srcs = ["python/graph_matcher_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":graph_matcher", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + py_library( name = "input_to_ops", srcs = ["python/input_to_ops.py"], @@ -43,6 +71,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":common", + ":graph_matcher", ":input_to_ops", "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/python:array_ops", @@ -58,6 +87,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":fold_batch_norms", + ":graph_matcher", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", @@ -147,10 +177,11 @@ py_test( py_test( name = "quantize_parameterized_test", - size = "medium", + size = "large", srcs = ["python/quantize_parameterized_test.py"], srcs_version = "PY2AND3", deps = [ + ":fold_batch_norms", ":quantize", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/quantize/python/copy_graph_test.py b/tensorflow/contrib/quantize/python/copy_graph_test.py index 0889f12de6a..7ff9ad9f841 100644 --- a/tensorflow/contrib/quantize/python/copy_graph_test.py +++ b/tensorflow/contrib/quantize/python/copy_graph_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tensorflow.quantized.mangle.copy_graph.""" +"""Tests for copy_graph.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index c4166895108..647d4044001 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -21,7 +21,9 @@ from __future__ import print_function import re from tensorflow.contrib import graph_editor from tensorflow.contrib.quantize.python import common +from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.contrib.quantize.python import input_to_ops +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn @@ -29,7 +31,7 @@ from tensorflow.python.ops import nn_ops def FoldBatchNorms(graph): - """Finds batch norm layers in the graph, folds them into preceding layers. + """Finds batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise convolution. @@ -40,10 +42,269 @@ def FoldBatchNorms(graph): Raises: ValueError: When batch norm folding fails. """ - # Fail immediately when the graph contains unsupported fused batch norm ops. - if any(op for op in graph.get_operations() if op.type == 'FusedBatchNorm'): - raise ValueError('Fused batch norm is not supported') + _FoldFusedBatchNorms(graph) + _FoldUnfusedBatchNorms(graph) + +def _FoldFusedBatchNorms(graph): + """Finds fused batch norm layers and folds them into preceding layers. + + Folding only affects the following layers: Conv2D, fully connected, depthwise + convolution. + + Args: + graph: Graph to walk and modify. + + Raises: + ValueError: When batch norm folding fails. + """ + for match in _FindFusedBatchNorms(graph): + scope, sep, _ = match.layer_op.name.rpartition('/') + # Make sure new ops are added to `graph` and put on the same device as + # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope + # named `scope`. Otherwise, TF creates a unique scope whose name starts with + # `scope`. + with graph.as_default(), graph.name_scope(scope + sep), ops.device( + match.bn_op.device): + # new weights = old weights * gamma / sqrt(variance + epsilon) + # new biases = -mean * gamma / sqrt(variance + epsilon) + beta + multiplier_tensor = match.gamma_tensor * math_ops.rsqrt( + match.variance_tensor + match.bn_op.get_attr('epsilon')) + bias_tensor = math_ops.subtract( + match.beta_tensor, match.mean_tensor * multiplier_tensor, name='bias') + + # The shape of depthwise weights is different, so we need to reshape the + # multiplier_tensor to ensure that the scaled_weight_tensor has the + # expected shape. + if match.layer_op.type == 'DepthwiseConv2dNative': + new_shape = [ + match.weight_tensor.get_shape().as_list()[2], + match.weight_tensor.get_shape().as_list()[3] + ] + multiplier_tensor = array_ops.reshape( + multiplier_tensor, new_shape, name='scale_reshape') + + # TODO(suharshs): This naming of the following ops needs to carefully + # follow the naming expected by quantize.py. Generalize the quantize code + # to not require these delicate naming conventions. + scaled_weight_tensor = math_ops.multiply( + match.weight_tensor, multiplier_tensor, name='mul_fold') + + new_layer_tensor = _CloneWithNewOperands( + match.layer_op, match.input_tensor, scaled_weight_tensor) + + bias_add_tensor = math_ops.add( + new_layer_tensor, bias_tensor, name='add_fold') + + nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor, + match.output_tensor) + if nodes_modified_count != 1: + raise ValueError( + 'Unexpected inputs to op: %s' % match.output_tensor.name) + + +def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): + """Clones layer_op with input_tensor and weight_tensor as new inputs.""" + new_layer_name = layer_op.name.split('/')[-1] + '_Fold' + if layer_op.type == 'Conv2D': + return nn_ops.conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), + data_format=layer_op.get_attr('data_format'), + name=new_layer_name) + elif layer_op.type == 'MatMul': + return math_ops.matmul( + input_tensor, + weight_tensor, + transpose_a=layer_op.get_attr('transpose_a'), + transpose_b=layer_op.get_attr('transpose_b'), + name=new_layer_name) + elif layer_op.type == 'DepthwiseConv2dNative': + return nn.depthwise_conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + name=new_layer_name) + else: + raise ValueError('Cannot handle operation of type: %s' % layer_op.type) + + +def _FindFusedBatchNorms(graph): + """Finds all ops and tensors related to found FusedBatchNorms. + + Args: + graph: Graph to inspect. + + Yields: + _FusedBatchNormMatches. + """ + input_pattern = graph_matcher.OpTypePattern('*') + weight_pattern = graph_matcher.OpTypePattern('*') + gamma_pattern = graph_matcher.OpTypePattern('*') + beta_pattern = graph_matcher.OpTypePattern('*') + mean_pattern = graph_matcher.OpTypePattern('*') + variance_pattern = graph_matcher.OpTypePattern('*') + + conv_pattern = graph_matcher.OpTypePattern( + 'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern]) + # MatMul has a Reshape between it and FusedBatchNorm. + matmul_pattern = graph_matcher.OpTypePattern( + 'MatMul', inputs=[input_pattern, weight_pattern]) + matmul_reshape_pattern = graph_matcher.OpTypePattern( + 'Reshape', inputs=[matmul_pattern, + graph_matcher.OpTypePattern('*')]) + + conv_batch_norm_pattern = graph_matcher.OpTypePattern( + 'FusedBatchNorm', + inputs=[ + conv_pattern, gamma_pattern, beta_pattern, mean_pattern, + variance_pattern + ]) + matmul_batch_norm_pattern = graph_matcher.OpTypePattern( + 'FusedBatchNorm', + inputs=[ + matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern, + variance_pattern + ]) + matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern( + 'Reshape', + inputs=[matmul_batch_norm_pattern, + graph_matcher.OpTypePattern('*')]) + + conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern) + matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern) + + def _GetCommonTensors(match_result): + """Gets tensors needed for FusedBatchNormMatch from match_result.""" + input_tensor = match_result.get_tensor(input_pattern) + weight_tensor = match_result.get_tensor(weight_pattern) + gamma_tensor = match_result.get_tensor(gamma_pattern) + beta_tensor = match_result.get_tensor(beta_pattern) + # FusedBatchNorm in training is different from that in inference. It takes + # empty 'mean' and empty 'variance', and produces the mean and the variance + # of the batch. Therefore, when is_training is true, mean_tensor and + # variance_tensor point to 1st and 2nd (0-based) output of bn_op, + # respectively; when is_training is false, they point to bn_op's inputs. + is_training = bn_op.get_attr('is_training') + if is_training: + mean_tensor = bn_op.outputs[1] + variance_tensor = bn_op.outputs[2] + else: + mean_tensor = match_result.get_tensor(mean_pattern) + variance_tensor = match_result.get_tensor(variance_pattern) + return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor) + + for match_result in conv_matcher.match_graph(graph): + layer_op = match_result.get_op(conv_pattern) + bn_op = match_result.get_op(conv_batch_norm_pattern) + # In the case of convolution the output_tensor is the output of bn_op. + output_tensor = bn_op.outputs[0] + + (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor) = _GetCommonTensors(match_result) + yield _FusedBatchNormMatch( + layer_op=layer_op, + bn_op=bn_op, + output_tensor=output_tensor, + input_tensor=input_tensor, + weight_tensor=weight_tensor, + gamma_tensor=gamma_tensor, + beta_tensor=beta_tensor, + mean_tensor=mean_tensor, + variance_tensor=variance_tensor) + + for match_result in matmul_matcher.match_graph(graph): + layer_op = match_result.get_op(matmul_pattern) + bn_op = match_result.get_op(matmul_batch_norm_pattern) + # In the MatMul case, the output of batch norm is reshaped back into a + # 2D tensor, so the output_tensor is the output of the Reshape op. + output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern) + output_tensor = output_reshape_op.outputs[0] + + (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor) = _GetCommonTensors(match_result) + yield _FusedBatchNormMatch( + layer_op=layer_op, + bn_op=bn_op, + output_tensor=output_tensor, + input_tensor=input_tensor, + weight_tensor=weight_tensor, + gamma_tensor=gamma_tensor, + beta_tensor=beta_tensor, + mean_tensor=mean_tensor, + variance_tensor=variance_tensor) + + +class _FusedBatchNormMatch(object): + """Contains all information related to a found FusedBatchNorm.""" + + def __init__(self, layer_op, bn_op, output_tensor, input_tensor, + weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor): + self._layer_op = layer_op + self._bn_op = bn_op + self._output_tensor = output_tensor + self._input_tensor = input_tensor + self._weight_tensor = weight_tensor + self._gamma_tensor = gamma_tensor + self._beta_tensor = beta_tensor + self._mean_tensor = mean_tensor + self._variance_tensor = variance_tensor + + @property + def layer_op(self): + return self._layer_op + + @property + def bn_op(self): + return self._bn_op + + @property + def output_tensor(self): + return self._output_tensor + + @property + def input_tensor(self): + return self._input_tensor + + @property + def weight_tensor(self): + return self._weight_tensor + + @property + def gamma_tensor(self): + return self._gamma_tensor + + @property + def beta_tensor(self): + return self._beta_tensor + + @property + def mean_tensor(self): + return self._mean_tensor + + @property + def variance_tensor(self): + return self._variance_tensor + + +def _FoldUnfusedBatchNorms(graph): + """Finds unfused batch norm layers and folds them into preceding layers. + + Folding only affects the following layers: Conv2D, fully connected, depthwise + convolution. + + Args: + graph: Graph to walk and modify. + + Raises: + ValueError: When batch norm folding fails. + """ input_to_ops_map = input_to_ops.InputToOps(graph) for bn in common.BatchNormGroups(graph): diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index ddedb0a2c06..5a66b38b155 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python import fold_batch_norms from tensorflow.python.framework import dtypes @@ -35,57 +34,32 @@ conv2d = layers.conv2d fully_connected = layers.fully_connected separable_conv2d = layers.separable_conv2d -_DEFAULT_BATCH_NORM_PARAMS = { - 'center': True, - 'scale': True, - 'decay': 1.0 - 0.003, - 'fused': False, -} - # TODO(suharshs): Use parameterized test once OSS TF supports it. class FoldBatchNormsTest(test_util.TensorFlowTestCase): def _RunTestOverParameters(self, test_fn): parameters_list = [ - # (relu, relu_op_name, with_bypass) - (nn_ops.relu6, 'Relu6', False), - (nn_ops.relu, 'Relu', False), - (nn_ops.relu6, 'Relu6', True), - (nn_ops.relu, 'Relu', True), + # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm) + (nn_ops.relu6, 'Relu6', False, False, False), + (nn_ops.relu, 'Relu', False, False, False), + (nn_ops.relu6, 'Relu6', True, False, False), + (nn_ops.relu, 'Relu', True, False, False), + (nn_ops.relu6, 'Relu6', False, True, False), + (nn_ops.relu, 'Relu', False, True, False), + (nn_ops.relu6, 'Relu6', True, True, False), + (nn_ops.relu, 'Relu', True, True, False), + # Fused batch norm always has scaling enabled. + (nn_ops.relu6, 'Relu6', False, True, True), + (nn_ops.relu, 'Relu', False, True, True), + (nn_ops.relu6, 'Relu6', True, True, True), + (nn_ops.relu, 'Relu', True, True, True), ] - for parameters in parameters_list: - test_fn(parameters[0], parameters[1], parameters[2]) + for params in parameters_list: + test_fn(params[0], params[1], params[2], params[3], params[4]) - def testFailsWithFusedBatchNorm(self): - self._RunTestOverParameters(self._TestFailsWithFusedBatchNorm) - - def _TestFailsWithFusedBatchNorm(self, relu, relu_op_name, with_bypass): - """Tests that batch norm fails when fused batch norm ops are present.""" - g = ops.Graph() - with g.as_default(): - batch_size, height, width = 5, 128, 128 - inputs = array_ops.zeros((batch_size, height, width, 3)) - out_depth = 3 if with_bypass else 32 - stride = 1 if with_bypass else 2 - activation_fn = None if with_bypass else relu - batch_norm_params = _DEFAULT_BATCH_NORM_PARAMS.copy() - batch_norm_params['fused'] = True - scope = 'test/test2' if with_bypass else 'test' - node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=batch_norm_params, - scope=scope) - if with_bypass: - node = math_ops.add(inputs, node, name='test/Add') - relu(node, name='test/' + relu_op_name) - - with self.assertRaises(ValueError): - fold_batch_norms.FoldBatchNorms(g) - - def _TestFoldConv2d(self, relu, relu_op_name, with_bypass): + def _TestFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling, + fused_batch_norm): """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. Args: @@ -93,6 +67,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): relu_op_name: String, name of the Relu* operation. with_bypass: Bool, when true there is an extra connection added from inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. """ g = ops.Graph() with g.as_default(): @@ -102,12 +78,17 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): stride = 1 if with_bypass else 2 activation_fn = None if with_bypass else relu scope = 'test/test2' if with_bypass else 'test' - node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) + node = conv2d( + inputs, + out_depth, [5, 5], + stride=stride, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) @@ -116,9 +97,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/weights/read', - scope + '/BatchNorm/batchnorm/mul']) + self._AssertInputOpsAre(folded_mul, [ + scope + '/weights/read', + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) + ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold']) folded_conv = g.get_operation_by_name(scope + '/convolution_Fold') @@ -129,16 +111,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/convolution_Fold', - scope + '/BatchNorm/batchnorm/sub']) + self._AssertInputOpsAre(folded_add, [ + scope + '/convolution_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) def testFoldConv2d(self): self._RunTestOverParameters(self._TestFoldConv2d) - def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass): + def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass, + has_scaling, fused_batch_norm): """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. Tests that folding works even with an input shape where some dimensions are @@ -149,6 +133,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): relu_op_name: String, name of the Relu* operation. with_bypass: Bool, when true there is an extra connection added from inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. """ g = ops.Graph() with g.as_default(): @@ -165,7 +151,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): weights_initializer=self._WeightInit(0.09), activation_fn=activation_fn, normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') @@ -176,7 +163,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') self._AssertInputOpsAre(folded_mul, [ - scope + '/weights/read', scope + '/BatchNorm/batchnorm/mul' + scope + '/weights/read', + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold']) @@ -188,7 +176,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') self._AssertInputOpsAre(folded_add, [ - scope + '/convolution_Fold', scope + '/BatchNorm/batchnorm/sub' + scope + '/convolution_Fold', + self._BathNormBiasName(scope, fused_batch_norm) ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) @@ -196,62 +185,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): def testFoldConv2dUnknownShape(self): self._RunTestOverParameters(self._TestFoldConv2dUnknownShape) - def _TestFoldConv2dWithoutScale(self, relu, relu_op_name, with_bypass): - """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. - - Args: - relu: Callable that returns an Operation, a factory method for the Relu*. - relu_op_name: String, name of the Relu* operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Relu*. - """ - g = ops.Graph() - with g.as_default(): - batch_size, height, width = 5, 128, 128 - inputs = array_ops.zeros((batch_size, height, width, 3)) - out_depth = 3 if with_bypass else 32 - stride = 1 if with_bypass else 2 - activation_fn = None if with_bypass else relu - bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS) - bn_params['scale'] = False - scope = 'test/test2' if with_bypass else 'test' - node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=bn_params, - scope=scope) - if with_bypass: - node = math_ops.add(inputs, node, name='test/Add') - relu(node, name='test/' + relu_op_name) - - fold_batch_norms.FoldBatchNorms(g) - - folded_mul = g.get_operation_by_name(scope + '/mul_fold') - self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/weights/read', - scope + '/BatchNorm/batchnorm/Rsqrt']) - self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold']) - - folded_conv = g.get_operation_by_name(scope + '/convolution_Fold') - self.assertEqual(folded_conv.type, 'Conv2D') - self._AssertInputOpsAre(folded_conv, - [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) - - folded_add = g.get_operation_by_name(scope + '/add_fold') - self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/convolution_Fold', - scope + '/BatchNorm/batchnorm/sub']) - output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] - self._AssertOutputGoesToOps(folded_add, g, output_op_names) - - def testFoldConv2dWithoutScale(self): - self._RunTestOverParameters(self._TestFoldConv2dWithoutScale) - - def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass): + def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass, + has_scaling, fused_batch_norm): """Tests folding cases: inputs -> FC with batch norm -> Relu*. Args: @@ -259,6 +194,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): relu_op_name: String, name of the Relu* operation. with_bypass: Bool, when true there is an extra connection added from inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. """ g = ops.Graph() with g.as_default(): @@ -267,12 +204,15 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): out_depth = 256 if with_bypass else 128 activation_fn = None if with_bypass else relu scope = 'test/test2' if with_bypass else 'test' - node = fully_connected(inputs, out_depth, - weights_initializer=self._WeightInit(0.03), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) + node = fully_connected( + inputs, + out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) @@ -281,9 +221,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/weights/read', - scope + '/BatchNorm/batchnorm/mul']) + self._AssertInputOpsAre(folded_mul, [ + scope + '/weights/read', + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) + ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold']) folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold') @@ -294,71 +235,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/MatMul_Fold', - scope + '/BatchNorm/batchnorm/sub']) + self._AssertInputOpsAre(folded_add, [ + scope + '/MatMul_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) def testFoldFullyConnectedLayer(self): self._RunTestOverParameters(self._TestFoldFullyConnectedLayer) - def _TestFoldFullyConnectedLayerWithoutScale(self, relu, relu_op_name, - with_bypass): - """Tests folding cases: inputs -> FC with batch norm -> Relu*. - - Args: - relu: Callable that returns an Operation, a factory method for the Relu*. - relu_op_name: String, name of the Relu* operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Relu*. - """ - g = ops.Graph() - with g.as_default(): - batch_size, depth = 5, 256 - inputs = array_ops.zeros((batch_size, depth)) - out_depth = 256 if with_bypass else 128 - activation_fn = None if with_bypass else relu - bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS) - bn_params['scale'] = False - scope = 'test/test2' if with_bypass else 'test' - node = fully_connected(inputs, out_depth, - weights_initializer=self._WeightInit(0.03), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=bn_params, - scope=scope) - if with_bypass: - node = math_ops.add(inputs, node, name='test/Add') - relu(node, name='test/' + relu_op_name) - - fold_batch_norms.FoldBatchNorms(g) - - folded_mul = g.get_operation_by_name(scope + '/mul_fold') - self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/weights/read', - scope + '/BatchNorm/batchnorm/Rsqrt']) - self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold']) - - folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold') - self.assertEqual(folded_conv.type, 'MatMul') - self._AssertInputOpsAre(folded_conv, - [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) - - folded_add = g.get_operation_by_name(scope + '/add_fold') - self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/MatMul_Fold', - scope + '/BatchNorm/batchnorm/sub']) - output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] - self._AssertOutputGoesToOps(folded_add, g, output_op_names) - - def testFoldFullyConnectedLayerWithoutScale(self): - self._RunTestOverParameters(self._TestFoldFullyConnectedLayerWithoutScale) - - def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass): + def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass, + has_scaling, fused_batch_norm): """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*. Args: @@ -366,6 +254,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): relu_op_name: String, name of the Relu* operation. with_bypass: Bool, when true there is an extra connection added from inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. """ g = ops.Graph() with g.as_default(): @@ -374,13 +264,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): stride = 1 if with_bypass else 2 activation_fn = None if with_bypass else relu scope = 'test/test2' if with_bypass else 'test' - node = separable_conv2d(inputs, None, [5, 5], stride=stride, - depth_multiplier=1.0, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) + node = separable_conv2d( + inputs, + None, [5, 5], + stride=stride, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) @@ -396,9 +291,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): scale_reshape = g.get_operation_by_name(scope + '/scale_reshape') self.assertEqual(scale_reshape.type, 'Reshape') - self._AssertInputOpsAre(scale_reshape, - [scope + '/BatchNorm/batchnorm/mul', - scope + '/scale_reshape/shape']) + self._AssertInputOpsAre(scale_reshape, [ + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm), + scope + '/scale_reshape/shape' + ]) self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold']) folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold') @@ -409,77 +305,35 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/depthwise_Fold', - scope + '/BatchNorm/batchnorm/sub']) + self._AssertInputOpsAre(folded_add, [ + scope + '/depthwise_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) def testFoldDepthwiseConv2d(self): self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) - def _TestFoldDepthwiseConv2dWithoutScale(self, relu, relu_op_name, - with_bypass): - """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*. + def _BatchNormParams(self, scale=True, fused=False): + return { + 'center': True, + 'scale': scale, + 'decay': 1.0 - 0.003, + 'fused': fused + } - Args: - relu: Callable that returns an Operation, a factory method for the Relu*. - relu_op_name: String, name of the Relu* operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Relu*. - """ - g = ops.Graph() - with g.as_default(): - batch_size, height, width = 5, 128, 128 - inputs = array_ops.zeros((batch_size, height, width, 3)) - stride = 1 if with_bypass else 2 - activation_fn = None if with_bypass else relu - bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS) - bn_params['scale'] = False - scope = 'test/test2' if with_bypass else 'test' - node = separable_conv2d(inputs, None, [5, 5], stride=stride, - depth_multiplier=1.0, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=bn_params, - scope=scope) - if with_bypass: - node = math_ops.add(inputs, node, name='test/Add') - relu(node, name='test/' + relu_op_name) + def _BatchNormMultiplierName(self, scope, has_scaling, fused): + if has_scaling: + if fused: + return scope + '/mul' + return scope + '/BatchNorm/batchnorm/mul' + return scope + '/BatchNorm/batchnorm/Rsqrt' - fold_batch_norms.FoldBatchNorms(g) - - folded_mul = g.get_operation_by_name(scope + '/mul_fold') - self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/depthwise_weights/read', - scope + '/scale_reshape']) - self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold']) - - scale_reshape = g.get_operation_by_name(scope + '/scale_reshape') - self.assertEqual(scale_reshape.type, 'Reshape') - self._AssertInputOpsAre(scale_reshape, - [scope + '/BatchNorm/batchnorm/Rsqrt', - scope + '/scale_reshape/shape']) - self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold']) - - folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold') - self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative') - self._AssertInputOpsAre(folded_conv, - [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) - - folded_add = g.get_operation_by_name(scope + '/add_fold') - self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/depthwise_Fold', - scope + '/BatchNorm/batchnorm/sub']) - output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] - self._AssertOutputGoesToOps(folded_add, g, output_op_names) - - def testFoldDepthwiseConv2dWithoutScale(self): - self._RunTestOverParameters(self._TestFoldDepthwiseConv2dWithoutScale) + def _BathNormBiasName(self, scope, fused): + if fused: + return scope + '/bias' + return scope + '/BatchNorm/batchnorm/sub' def _WeightInit(self, stddev): """Returns a truncated normal variable initializer. diff --git a/tensorflow/contrib/quantize/python/graph_matcher.py b/tensorflow/contrib/quantize/python/graph_matcher.py new file mode 100644 index 00000000000..e3581cc5590 --- /dev/null +++ b/tensorflow/contrib/quantize/python/graph_matcher.py @@ -0,0 +1,200 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities that match patterns in a tf.Graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class OpTypePattern(object): + """A tree pattern that matches TF expressions with certain op types.""" + + def __init__(self, op_type, name=None, inputs=None): + """Initializes an OpTypePattern. + + Args: + op_type: string that specifies the allowed types of the root. It can be + (1) an op type, e.g. 'Conv2D', + (2) '*', i.e. wildcard, or + (3) multiple op types separated by '|', e.g., 'Relu|Relu6'. + We could use regex strings, which might be worthwhile when we have many + similar TF op types. + name: Optional string. The name of the pattern that can be looked up in + MatchResult. + inputs: Optional list of `OpTypePattern`s or strings that specify the + patterns for the inputs of a matching op. If None, this pattern accepts + any inputs of a matching op. + """ + self._op_type = op_type + self._name = name + if inputs is None: + inputs = [] + self._inputs = [ + input_pattern if isinstance(input_pattern, OpTypePattern) else + OpTypePattern(input_pattern) for input_pattern in inputs + ] + + @property + def op_type(self): + return self._op_type + + @property + def inputs(self): + return self._inputs + + @property + def name(self): + return self._name + + +class MatchResult(object): + r"""Encapsulates the result of a match done by GraphMatcher. + + MatchResult contains a map from OpTypePattern to the matching op and tensor. + When the matching op has multiple output tensors, the matching tensor is the + output tensor used by the matching op of the parent pattern. E.g., when we + match graph + + - + + / \y0 y1/ \ + x split z + | + y (nodes are ops; edges are going up) + + against add_pattern defined as + + y1_pattern = OpTypePattern('*') + z_pattern = OpTypePattern('*') + add_pattern = OpTypePattern('+', inputs=[y1_pattern, z_pattern]) + + the matching op of `y1_pattern` is `split`, and the matching tensor of + `y1_pattern` + is `y1` not `y0`. + """ + + def __init__(self): + self._pattern_to_op_tensor = {} + self._name_to_pattern = {} + + def add(self, pattern, op, tensor): + self._pattern_to_op_tensor[pattern] = op, tensor + if pattern.name is not None: + if pattern.name in self._name_to_pattern: + raise ValueError( + 'Name %s is already bound to another pattern' % pattern.name) + self._name_to_pattern[pattern.name] = pattern + + def _to_pattern(self, pattern_or_name): + if isinstance(pattern_or_name, OpTypePattern): + return pattern_or_name + + if isinstance(pattern_or_name, str): + return self._name_to_pattern[pattern_or_name] + + raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.' + % type(pattern_or_name)) + + def get_op(self, pattern_or_name): + return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][0] + + def get_tensor(self, pattern_or_name): + return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][1] + + +class GraphMatcher(object): + """Checks if a particular subgraph matches a given pattern.""" + + def __init__(self, pattern): + """Initializes a GraphMatcher. + + Args: + pattern: The `OpTypePattern` against which `GraphMatcher` matches + subgraphs. + """ + self._pattern = pattern + + def _match_pattern(self, pattern, op, tensor): + """Returns whether an TF expression rooted at `op` matches `pattern`. + + If there is a match, adds to `self._match_result` the matching op and tensor + with key `pattern`. + + Args: + pattern: An `OpTypePattern`. + op: A `tf.Operation` to match against the pattern. + tensor: the output `tf.Tensor` of `op` that is used by the matching op of + `pattern`'s parent. Can be None if `pattern` is already the root of the + pattern tree. + + Returns: + True if an TF expression rooted at `op` matches `pattern`. + """ + if pattern.op_type != '*': + if op.type not in pattern.op_type.split('|'): + return False + + self._match_result.add(pattern, op, tensor) + + if not pattern.inputs: + # If pattern.inputs is empty, skips the rest and accepts all the inputs. + return True + + return len(op.inputs) == len(pattern.inputs) and all([ + self._match_pattern(input_pattern, input_tensor.op, input_tensor) + for input_tensor, input_pattern in zip(op.inputs, pattern.inputs) + ]) + + def match_op(self, op): + """Matches `op` against `self._pattern`. + + Args: + op: `tf.Operation` to match against the pattern. + + Returns: + Returns a `MatchResult` if `op` matches the pattern; otherwise, returns + None. + """ + self._match_result = MatchResult() + if not self._match_pattern(self._pattern, op, tensor=None): + return None + return self._match_result + + def match_ops(self, ops): + """Matches each operation in `ops` against `self._pattern`. + + Args: + ops: collection of `tf.Operation` to match against the pattern. + + Yields: + `MatchResult` for each `tf.Operation` that matches the pattern. + """ + for op in ops: + match_result = self.match_op(op) + if match_result: + yield match_result + + def match_graph(self, graph): + """Matches each operation in `graph` against `self._pattern`. + + Args: + graph: `tf.Graph` containing operations to match. + + Yields: + `MatchResult` for each `tf.Operation` in `graph` that matches the pattern. + """ + # Python 3.3.2+ implements `yield from`, but for now: + for match_result in self.match_ops(graph.get_operations()): + yield match_result diff --git a/tensorflow/contrib/quantize/python/graph_matcher_test.py b/tensorflow/contrib/quantize/python/graph_matcher_test.py new file mode 100644 index 00000000000..e1572865e42 --- /dev/null +++ b/tensorflow/contrib/quantize/python/graph_matcher_test.py @@ -0,0 +1,130 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for graph_matcher.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python import ops as contrib_ops +from tensorflow.contrib.layers.python.layers import initializers +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import graph_matcher +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + + +class GraphMatcherTest(test_util.TensorFlowTestCase): + + def test_conv_layer(self): + g = ops.Graph() + with g.as_default(): + inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3]) + + with contrib_ops.arg_scope( + [layers.batch_norm], fused=True, is_training=True, trainable=True): + return layers.convolution( + inputs, + num_outputs=16, + kernel_size=3, + stride=1, + padding='VALID', + activation_fn=nn_ops.relu, + normalizer_fn=layers.batch_norm, + normalizer_params={}, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + trainable=True, + scope=None) + + inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs') + relu_pattern = graph_matcher.OpTypePattern( + 'Relu', + name='relu', + inputs=[ + graph_matcher.OpTypePattern( + 'FusedBatchNorm', + inputs=[ + graph_matcher.OpTypePattern( + 'Conv2D', inputs=[inputs_pattern, '*']), '*', '*', '*', + '*' + ]) + ]) + matcher = graph_matcher.GraphMatcher(relu_pattern) + match_results = list(matcher.match_graph(g)) + self.assertEqual(1, len(match_results)) + match_result = match_results[0] + self.assertEqual(match_result.get_tensor(inputs_pattern), inputs) + self.assertEqual(match_result.get_tensor('inputs'), inputs) + + def test_multiple_outputs(self): + # - + + # / \y0 y1/ \ + # x split z + # | + # y (nodes are ops; edges are going up) + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtypes.float32, shape=[1], name='x') + y = array_ops.placeholder(dtypes.float32, shape=[2], name='y') + y0, y1 = array_ops.split(y, num_or_size_splits=2, axis=0) + z = array_ops.placeholder(dtypes.float32, shape=[1], name='z') + math_ops.add(x, y0) + math_ops.subtract(y1, z) + + y1_pattern = graph_matcher.OpTypePattern('*') + minus_pattern = graph_matcher.OpTypePattern('Sub', inputs=[y1_pattern, '*']) + matcher = graph_matcher.GraphMatcher(minus_pattern) + + match_results = list(matcher.match_graph(g)) + self.assertEqual(1, len(match_results)) + match_result = match_results[0] + + self.assertEqual(y0.op, y1.op) + self.assertEqual(match_result.get_op(y1_pattern), y1.op) + self.assertEqual(match_result.get_tensor(y1_pattern), y1) + + def test_oneof_pattern(self): + # - + + # / \ / \ + # x y z + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtypes.float32, shape=[], name='x') + y = array_ops.placeholder(dtypes.float32, shape=[], name='y') + z = array_ops.placeholder(dtypes.float32, shape=[], name='z') + plus = x + y + minus = y - z + + add_or_sub_pattern = graph_matcher.OpTypePattern( + 'Add|Sub', inputs=['*', '*']) + matcher = graph_matcher.GraphMatcher(add_or_sub_pattern) + self.assertEqual([ + match_result.get_op(add_or_sub_pattern) + for match_result in matcher.match_graph(g) + ], [plus.op, minus.op]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index b5a32a7266a..31fcd66dfb7 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import fold_batch_norms from tensorflow.contrib.quantize.python import quantize from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -35,18 +36,11 @@ conv2d = layers.conv2d fully_connected = layers.fully_connected separable_conv2d = layers.separable_conv2d -_DEFAULT_BATCH_NORM_PARAMS = { - 'center': True, - 'scale': True, - 'decay': 1.0 - 0.003, - 'fused': False, -} - -# TODO(suharshs): Use parameterized test once OSS TF supports it. class QuantizeTest(test_util.TensorFlowTestCase): - def _RunTestOverParameters(self, test_fn): + def _RunWithoutBatchNormTestOverParameters(self, test_fn): + # TODO(suharshs): Use parameterized test once OSS TF supports it. parameters_list = [ # (activation, activation_op_name, with_bypass, delay) (nn_ops.relu6, 'Relu6', False, None), @@ -60,10 +54,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): (array_ops.identity, 'Identity', True, None), (nn_ops.relu6, 'Relu6', True, 5000), (nn_ops.relu, 'Relu', True, 5000), - (array_ops.identity, 'Identity', True, 5000) + (array_ops.identity, 'Identity', True, 5000), ] - for parameters in parameters_list: - test_fn(parameters[0], parameters[1], parameters[2], parameters[3]) + for params in parameters_list: + test_fn(params[0], params[1], params[2], params[3]) def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name, with_bypass, delay): @@ -137,7 +131,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def testQuantize_Conv2dWithoutBatchNorm(self): - self._RunTestOverParameters(self._TestQuantize_Conv2dWithoutBatchNorm) + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_Conv2dWithoutBatchNorm) def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name, with_bypass, delay): @@ -210,7 +205,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def testQuantize_FCWithoutBatchNorm(self): - self._RunTestOverParameters(self._TestQuantize_FCWithoutBatchNorm) + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_FCWithoutBatchNorm) def _TestQuantize_DepthwiseConv2dWithoutBatchNorm( self, activation, activation_op_name, with_bypass, delay): @@ -284,11 +280,43 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def testQuantize_DepthwiseConv2dWithoutBatchNorm(self): - self._RunTestOverParameters( + self._RunWithoutBatchNormTestOverParameters( self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) + def _RunBatchNormTestOverParameters(self, test_fn): + # TODO(suharshs): Use parameterized test once OSS TF supports it. + parameters_list = [ + # (activation, activation_op_name, with_bypass, delay, fused_batch_norm) + (nn_ops.relu6, 'Relu6', False, None, False), + (nn_ops.relu, 'Relu', False, None, False), + (array_ops.identity, 'Identity', False, None, False), + (nn_ops.relu6, 'Relu6', False, 5000, False), + (nn_ops.relu, 'Relu', False, 5000, False), + (array_ops.identity, 'Identity', False, 5000, False), + (nn_ops.relu6, 'Relu6', True, None, False), + (nn_ops.relu, 'Relu', True, None, False), + (array_ops.identity, 'Identity', True, None, False), + (nn_ops.relu6, 'Relu6', True, 5000, False), + (nn_ops.relu, 'Relu', True, 5000, False), + (array_ops.identity, 'Identity', True, 5000, False), + (nn_ops.relu6, 'Relu6', False, None, True), + (nn_ops.relu, 'Relu', False, None, True), + (array_ops.identity, 'Identity', False, None, True), + (nn_ops.relu6, 'Relu6', False, 5000, True), + (nn_ops.relu, 'Relu', False, 5000, True), + (array_ops.identity, 'Identity', False, 5000, True), + (nn_ops.relu6, 'Relu6', True, None, True), + (nn_ops.relu, 'Relu', True, None, True), + (array_ops.identity, 'Identity', True, None, True), + (nn_ops.relu6, 'Relu6', True, 5000, True), + (nn_ops.relu, 'Relu', True, 5000, True), + (array_ops.identity, 'Identity', True, 5000, True) + ] + for params in parameters_list: + test_fn(params[0], params[1], params[2], params[3], params[4]) + def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay): + with_bypass, delay, fused_batch_norm): """Tests quantization: inputs -> Conv2d with batch norm -> Activation. Args: @@ -298,25 +326,29 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. """ self._testQuantize_Conv2dWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=True) self._testQuantize_Conv2dWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=False) def testQuantize_Conv2dWithBatchNorm(self): - self._RunTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) + self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, use_ema): + with_bypass, delay, fused_batch_norm, + use_ema): """Tests quantization: inputs -> Conv2d with batch norm -> Activation. Args: @@ -326,6 +358,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. use_ema: Bool, when true uses EMA quantization for BN folded weights. """ graph = ops.Graph() @@ -337,39 +370,29 @@ class QuantizeTest(test_util.TensorFlowTestCase): stride = 1 if with_bypass else 2 out_depth = 3 if with_bypass else 32 scope = 'test/test2' if with_bypass else 'test' - node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=None, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) - # Manually fold the batch norm. - weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0] - bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul') - .outputs[0]) - mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold') - stride = [stride, stride] - conv_fold = nn_ops.convolution( - input=inputs, - filter=mul_fold, + node = conv2d( + inputs, + out_depth, [5, 5], + stride=stride, padding='SAME', - strides=stride, - data_format='NHWC', - name=scope + '/convolution_Fold') - bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') - .outputs[0]) - add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold') + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + # Manually add a bypass (optionaly) and an activation. if with_bypass: - node = math_ops.add(inputs, add_fold, name='test/Add') - else: - node = add_fold + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') + fold_batch_norms.FoldBatchNorms(graph) + quantize.Quantize( graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) @@ -413,7 +436,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay): + with_bypass, delay, fused_batch_norm): """Tests quantization: inputs -> FC with batch norm -> Activation. Args: @@ -423,25 +446,29 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. """ self._testQuantize_FCWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=True) self._testQuantize_FCWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=False) def testQuantize_FCWithBatchNorm(self): - self._RunTestOverParameters(self._TestQuantize_FCWithBatchNorm) + self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm) def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, use_ema): + with_bypass, delay, fused_batch_norm, + use_ema): """Tests quantization: inputs -> FC with batch norm -> Activation. Args: @@ -451,6 +478,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. use_ema: Bool, when true uses EMA quantization for BN folded weights. """ graph = ops.Graph() @@ -461,32 +489,27 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs = array_ops.zeros((batch_size, depth)) out_depth = 256 if with_bypass else 128 scope = 'test/test2' if with_bypass else 'test' - node = fully_connected(inputs, out_depth, - weights_initializer=self._WeightInit(0.03), - activation_fn=None, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) - # Manually fold the batch norm. - weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0] - bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul') - .outputs[0]) - mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold') - fc_fold = math_ops.matmul(inputs, mul_fold, name=scope + '/MatMul_Fold') - bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') - .outputs[0]) - add_fold = math_ops.add(fc_fold, bn_bias, name=scope + '/add_fold') + node = fully_connected( + inputs, + out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + # Manually add a bypass (optionaly) and an activation. if with_bypass: - node = math_ops.add(inputs, add_fold, name='test/Add') - else: - node = add_fold + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') + fold_batch_norms.FoldBatchNorms(graph) + quantize.Quantize( graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) @@ -530,7 +553,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def _TestQuantize_DepthwiseConv2dWithBatchNorm( - self, activation, activation_op_name, with_bypass, delay): + self, activation, activation_op_name, with_bypass, delay, + fused_batch_norm): """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. Args: @@ -540,26 +564,30 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. """ self._testQuantize_DepthwiseConv2dWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=True) self._testQuantize_DepthwiseConv2dWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=False) def testQuantize_DepthwiseConv2dWithBatchNorm(self): - self._RunTestOverParameters( - self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) + self._RunBatchNormTestOverParameters( + self._TestQuantize_DepthwiseConv2dWithBatchNorm) def _testQuantize_DepthwiseConv2dWithBatchNorm( - self, activation, activation_op_name, with_bypass, delay, use_ema): + self, activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema): """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. Args: @@ -569,6 +597,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. use_ema: Bool, when true uses EMA quantization for BN folded weights. """ graph = ops.Graph() @@ -579,46 +608,30 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs = array_ops.zeros((batch_size, height, width, depth)) stride = 1 if with_bypass else 2 scope = 'test/test2' if with_bypass else 'test' - node = separable_conv2d(inputs, None, [5, 5], stride=stride, - depth_multiplier=1.0, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=None, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) - # Manually fold the batch norm. - weights = (graph.get_operation_by_name(scope + '/depthwise_weights/read') - .outputs[0]) - bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul') - .outputs[0]) - new_shape = [ - weights.get_shape().as_list()[2], weights.get_shape().as_list()[3] - ] - bn_mult_reshaped = array_ops.reshape( - bn_mult, new_shape, name=scope + '/gamma_reshape') - mul_fold = math_ops.multiply( - weights, bn_mult_reshaped, name=scope + '/mul_fold') - stride = [1, stride, stride, 1] - conv_fold = nn_ops.depthwise_conv2d( - input=inputs, - filter=mul_fold, + node = separable_conv2d( + inputs, + None, [5, 5], + stride=stride, + depth_multiplier=1.0, padding='SAME', - strides=stride, - name=scope + '/depthwise_Fold') - bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') - .outputs[0]) - add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold') + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + # Manually add a bypass (optionaly) and an activation. if with_bypass: - node = math_ops.add(inputs, add_fold, name='test/Add') - else: - node = add_fold + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') + fold_batch_norms.FoldBatchNorms(graph) + quantize.Quantize( graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) quantization_node_name = 'FakeQuantWithMinMaxVars' @@ -660,6 +673,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): if delay else 'control_dependency') self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + def _BatchNormParams(self, fused=False): + return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused} + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. From 93871a811eab7457f8e36ee4905234aa1a9ea8c8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 09:13:44 -0700 Subject: [PATCH 23/41] Remove duplicated `smart_cond()` code. PiperOrigin-RevId: 172891249 --- .../training/python/training/bucket_ops.py | 4 +-- tensorflow/python/BUILD | 1 + tensorflow/python/training/input.py | 26 ++++--------------- 3 files changed, 8 insertions(+), 23 deletions(-) diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py index 5523cc375fc..95fbc50cba7 100644 --- a/tensorflow/contrib/training/python/training/bucket_ops.py +++ b/tensorflow/contrib/training/python/training/bucket_ops.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -47,7 +48,6 @@ _dtypes = input_py._dtypes _store_sparse_tensors = input_py._store_sparse_tensors _validate_keep_input = input_py._validate_keep_input _shapes = input_py._shapes -_smart_cond = input_py._smart_cond _which_queue = input_py._which_queue # pylint: enable=protected-access @@ -239,7 +239,7 @@ def bucket(tensors, ] return control_flow_ops.group(*enqueues, name="group_enqueues") - maybe_enqueue = _smart_cond( + maybe_enqueue = utils.smart_cond( keep_input, enqueue_which, control_flow_ops.no_op) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 21cdaec4778..e63c554e472 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2638,6 +2638,7 @@ py_library( ":init_ops", ":io_ops", ":io_ops_gen", + ":layers_base", ":lib", ":lookup_ops", ":math_ops", diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 704017c2446..36f97960ddd 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -32,7 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -413,22 +413,6 @@ def _as_original_type(original_tensors, tensor_list): return tensor_list -def _smart_cond(pred, if_true, if_false): - """A `tf.cond` that does nothing when the condition is static.""" - pred = ops.convert_to_tensor(pred) - static_pred = tensor_util.constant_value(pred) - if static_pred is not None: - if static_pred: - return if_true() - else: - return if_false() - else: - return control_flow_ops.cond( - pred, - if_true, - if_false) - - def _store_sparse_tensors(tensor_list, enqueue_many, keep_input, shared_map_ops=None): """Store SparseTensors for feeding into batch, etc. @@ -480,13 +464,13 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input, map_op_name = shared_map_op.name if shared_map_op else None def _maybe_store_sparse(t, map_op_name, keep_input): """Conditionally store a single sparse Tensor.""" - return _smart_cond( + return utils.smart_cond( keep_input, lambda: _store_sparse(t, shared_name=map_op_name), lambda: constant_op.constant(-1, dtypes.int64)) def _maybe_store_many_sparse(t, map_op_name, keep_input): """Conditionally store multiple sparse Tensors.""" - out_tensor = _smart_cond( + out_tensor = utils.smart_cond( keep_input, lambda: _store_many_sparse(t, shared_name=map_op_name), lambda: -1 * array_ops.ones(array_ops.shape(t)[0:1], dtypes.int64)) @@ -667,7 +651,7 @@ def _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input): enqueue_ops = [enqueue_fn(_select_which_to_enqueue(x, keep_input)) for x in tensor_list_list] else: - enqueue_ops = [_smart_cond( + enqueue_ops = [utils.smart_cond( keep_input, lambda: enqueue_fn(tl), # pylint:disable=cell-var-from-loop control_flow_ops.no_op) for tl in tensor_list_list] @@ -684,7 +668,7 @@ def _enqueue(queue, tensor_list, threads, enqueue_many, keep_input): enqueue_ops = [ enqueue_fn(_select_which_to_enqueue(tensor_list, keep_input))] * threads else: - enqueue_ops = [_smart_cond( + enqueue_ops = [utils.smart_cond( keep_input, lambda: enqueue_fn(tensor_list), control_flow_ops.no_op)] * threads From 5c24b8b1e5f3f1145e123a5a159b958ea9fc8c3d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 09:17:03 -0700 Subject: [PATCH 24/41] XLA refactoring PiperOrigin-RevId: 172891551 --- .../xla/legacy_flags/debug_options_flags.cc | 6 ++--- tensorflow/compiler/xla/protobuf_util.cc | 25 ----------------- tensorflow/compiler/xla/protobuf_util.h | 13 +++------ .../compiler/xla/service/cpu/cpu_compiler.cc | 27 +++++++++---------- .../compiler/xla/service/gpu/gpu_compiler.cc | 10 +++---- tensorflow/compiler/xla/xla.proto | 4 +-- 6 files changed, 27 insertions(+), 58 deletions(-) diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 8892bfbe929..f2cdd9669c7 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -206,9 +206,9 @@ void AllocateFlags() { flag_values->xla_gpu_disable_multi_streaming(), "If true, multi-streaming in the GPU backend is disabled."), tensorflow::Flag( - "xla_dump_debug_json_to", - flag_values->mutable_xla_dump_debug_json_to(), - "Dump compilation artifacts as JSON into this directory."), + "xla_dump_hlo_proto_to", + flag_values->mutable_xla_dump_hlo_proto_to(), + "Dump compilation artifacts as proto binary into this directory."), tensorflow::Flag( "xla_test_all_output_layouts", bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index c032cb8dc5a..787725e884c 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -37,20 +37,6 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1, return (serialized1 == serialized2); } -StatusOr ToJson(const tensorflow::protobuf::Message& message) { - string json_output; - tensorflow::protobuf::util::JsonPrintOptions json_options; - json_options.add_whitespace = true; - json_options.always_print_primitive_fields = true; - auto status = tensorflow::protobuf::util::MessageToJsonString( - message, &json_output, json_options); - if (!status.ok()) { - return InternalError("MessageToJsonString failed: %s", - status.error_message().data()); - } - return json_output; -} - namespace { string SanitizeFilename(const string& file_name) { @@ -65,17 +51,6 @@ string SanitizeFilename(const string& file_name) { } // namespace -Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message, - const string& directory, const string& file_name) { - TF_ASSIGN_OR_RETURN(const string json_output, ToJson(message)); - - tensorflow::Env* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); - string safe_file_name = SanitizeFileName(file_name) + ".json"; - const string path = tensorflow::io::JoinPath(directory, safe_file_name); - return tensorflow::WriteStringToFile(env, path, json_output); -} - Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name) { tensorflow::Env* env = tensorflow::Env::Default(); diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 7accb22e0c7..3667621367c 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -32,17 +32,12 @@ namespace protobuf_util { extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, const tensorflow::protobuf::Message& m2); -// Returns 'message' as a JSON string. -StatusOr ToJson(const tensorflow::protobuf::Message& message); - -// Writes the given message in binary proto or JSON format to the path formed by -// joining 'directory/file_name.pb' (or file_name.json). The 'directory' is -// recursively created if it doesn't already exist, and the 'file_name' is -// sanitized by replacing illegal characters with underscore '_'. +// Writes the given message in binary proto to the path formed by joining +// 'directory/file_name.pb'. The 'directory' is recursively created if it +// doesn't already exist, and the 'file_name' is sanitized by replacing +// illegal characters with underscore '_'. Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name); -Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message, - const string& directory, const string& file_name); } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index ce4d109214b..06e7ec0c7cb 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -475,8 +475,8 @@ StatusOr> CpuCompiler::Compile( // ownership is std::moved. const bool embed_ir_in_executable = module->config().debug_options().xla_embed_ir_in_executable(); - const string dump_debug_json_to = - module->config().debug_options().xla_dump_debug_json_to(); + const string xla_dump_hlo_proto_to = + module->config().debug_options().xla_dump_hlo_proto_to(); if (options::CpuParallelBackendRequested(module->config())) { VLOG(1) << "Using parallel cpu backend"; @@ -496,10 +496,10 @@ StatusOr> CpuCompiler::Compile( // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - if (!dump_debug_json_to.empty()) { + if (!xla_dump_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, dump_debug_json_to, module->name())); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_hlo_proto_to, module->name())); } // If we are using the parallel CPU backend, we need to create map from @@ -603,12 +603,11 @@ StatusOr> CpuCompiler::Compile( // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - if (!dump_debug_json_to.empty()) { + if (!xla_dump_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, dump_debug_json_to, module->name())); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_hlo_proto_to, module->name())); } - // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from // GetEmbeddedComputations guarantees that a called computation occurs @@ -775,12 +774,12 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - const string dump_debug_json_to = - module->config().debug_options().xla_dump_debug_json_to(); - if (!dump_debug_json_to.empty()) { + const string xla_dump_hlo_proto_to = + module->config().debug_options().xla_dump_hlo_proto_to(); + if (!xla_dump_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, dump_debug_json_to, module->name())); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_hlo_proto_to, module->name())); } IrEmitter ir_emitter(*module, *assignment, &llvm_module, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 3e16e4e3c42..9c7ca9ea38e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -318,12 +318,12 @@ StatusOr> GpuCompiler::Compile( // print one ourselves. XLA_VLOG_LINES(2, buffer_assignment->ToString()); - const string dump_debug_json_to = - module->config().debug_options().xla_dump_debug_json_to(); - if (!dump_debug_json_to.empty()) { + const string xla_dump_hlo_proto_to = + module->config().debug_options().xla_dump_hlo_proto_to(); + if (!xla_dump_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *buffer_assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, dump_debug_json_to, module->name())); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_hlo_proto_to, module->name())); } IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 7f4bd26d1bc..ce3c3eee68a 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -82,8 +82,8 @@ message DebugOptions { // Dump all HLO modules as text into the provided directory path. string xla_generate_hlo_text_to = 7; - // Dump compilation artifacts as JSON into this directory. - string xla_dump_debug_json_to = 8; + // Dump compilation artifacts in binary proto into this directory. + string xla_dump_hlo_proto_to = 8; // Instrument the computation to collect per-HLO cycle counts. bool xla_hlo_profile = 9; From f86588ce8fb38ab3a6afc21eb08d2a2097b56adc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 09:50:28 -0700 Subject: [PATCH 25/41] Added gradient op for QR decomposition PiperOrigin-RevId: 172895297 --- tensorflow/python/kernel_tests/qr_op_test.py | 66 ++++++++++++++++++-- tensorflow/python/ops/linalg_grad.py | 42 +++++++++++-- 2 files changed, 98 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py index b4fd89bd037..8848c15e765 100644 --- a/tensorflow/python/kernel_tests/qr_op_test.py +++ b/tensorflow/python/kernel_tests/qr_op_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -140,11 +141,11 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_): x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1])) for i in range(new_first_dim): if full_matrices_: - np_q_reshape[i,:,:], _ = \ - np.linalg.qr(x_reshape[i,:,:], mode="complete") + np_q_reshape[i, :, :], _ = np.linalg.qr( + x_reshape[i, :, :], mode="complete") else: - np_q_reshape[i,:,:], _ = \ - np.linalg.qr(x_reshape[i,:,:], mode="reduced") + np_q_reshape[i, :, :], _ = np.linalg.qr( + x_reshape[i, :, :], mode="reduced") np_q = np.reshape(np_q_reshape, q_dims) CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:])) CheckApproximation(self, x_np, q_tf_val, r_tf_val) @@ -153,6 +154,46 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_): return Test +class QrGradOpTest(test.TestCase): + pass + + +def _GetQrGradOpTest(dtype_, shape_, full_matrices_): + + def Test(self): + np.random.seed(42) + a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_) + if dtype_ in [np.complex64, np.complex128]: + a += 1j * np.random.uniform( + low=-1.0, high=1.0, size=shape_).astype(dtype_) + # Optimal stepsize for central difference is O(epsilon^{1/3}). + epsilon = np.finfo(dtype_).eps + delta = 0.1 * epsilon**(1.0 / 3.0) + if dtype_ in [np.float32, np.complex64]: + tol = 3e-2 + else: + tol = 1e-6 + with self.test_session(use_gpu=True): + tf_a = constant_op.constant(a) + tf_b = linalg_ops.qr(tf_a, full_matrices=full_matrices_) + for b in tf_b: + x_init = np.random.uniform( + low=-1.0, high=1.0, size=shape_).astype(dtype_) + if dtype_ in [np.complex64, np.complex128]: + x_init += 1j * np.random.uniform( + low=-1.0, high=1.0, size=shape_).astype(dtype_) + theoretical, numerical = gradient_checker.compute_gradient( + tf_a, + tf_a.get_shape().as_list(), + b, + b.get_shape().as_list(), + x_init_value=x_init, + delta=delta) + self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) + + return Test + + if __name__ == "__main__": for dtype in np.float32, np.float64, np.complex64, np.complex128: for rows in 1, 2, 5, 10, 32, 100: @@ -168,4 +209,21 @@ if __name__ == "__main__": _AddTest(QrOpTest, "Qr", name, _GetQrOpTest(dtype, shape, full_matrices, use_static_shape)) + + # TODO(pfau): Get working with complex types. + # TODO(pfau): Get working with full_matrices when rows != cols + # TODO(pfau): Get working when rows < cols + # TODO(pfau): Get working with shapeholders (dynamic shapes) + for full_matrices in False, True: + for dtype in np.float32, np.float64: + for rows in 1, 2, 5, 10: + for cols in 1, 2, 5, 10: + if rows == cols or (not full_matrices and rows > cols): + for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): + shape = batch_dims + (rows, cols) + name = "%s_%s_full_%s" % (dtype.__name__, + "_".join(map(str, shape)), + full_matrices) + _AddTest(QrGradOpTest, "QrGrad", name, + _GetQrGradOpTest(dtype, shape, full_matrices)) test.main() diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py index ec263591e10..8a76fe3ce55 100644 --- a/tensorflow/python/ops/linalg_grad.py +++ b/tensorflow/python/ops/linalg_grad.py @@ -81,6 +81,36 @@ def _CholeskyGrad(op, grad): return grad_a * 0.5 +@ops.RegisterGradient("Qr") +def _QrGrad(op, dq, dr): + """Gradient for Qr.""" + q, r = op.outputs + if q.dtype.is_complex: + raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype) + if (r.shape.ndims is None or r.shape.as_list()[-2] is None or + r.shape.as_list()[-1] is None): + raise NotImplementedError("QrGrad not implemented with dynamic shapes.") + if r.shape[-2].value != r.shape[-1].value: + raise NotImplementedError("QrGrad not implemented when ncols > nrows " + "or full_matrices is true and ncols != nrows.") + + qdq = math_ops.matmul(q, dq, adjoint_a=True) + qdq_ = qdq - _linalg.adjoint(qdq) + rdr = math_ops.matmul(r, dr, adjoint_b=True) + rdr_ = rdr - _linalg.adjoint(rdr) + tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0) + + def _TriangularSolve(x, r): + """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri.""" + return _linalg.adjoint( + linalg_ops.matrix_triangular_solve( + r, _linalg.adjoint(x), lower=False, adjoint=False)) + + grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r)) + grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r) + return grad_a + grad_b + + @ops.RegisterGradient("MatrixSolve") def _MatrixSolveGrad(op, grad): """Gradient for MatrixSolve.""" @@ -105,7 +135,7 @@ def _MatrixSolveLsGrad(op, grad): # b) Implement a symmetric rank-k update op instead of computing # x*z + transpose(x*z). This pattern occurs other places in TensorFlow. - def _overdetermined(op, grad): + def _Overdetermined(op, grad): """Gradients for the overdetermined case of MatrixSolveLs. This is the backprop for the solution to the normal equations of the first @@ -130,7 +160,7 @@ def _MatrixSolveLsGrad(op, grad): grad_b = math_ops.matmul(a, z) return (grad_a, grad_b, None) - def _underdetermined(op, grad): + def _Underdetermined(op, grad): """Gradients for the underdetermined case of MatrixSolveLs. This is the backprop for the solution to the normal equations of the second @@ -162,16 +192,16 @@ def _MatrixSolveLsGrad(op, grad): matrix_shape = op.inputs[0].get_shape()[-2:] if matrix_shape.is_fully_defined(): if matrix_shape[-2] >= matrix_shape[-1]: - return _overdetermined(op, grad) + return _Overdetermined(op, grad) else: - return _underdetermined(op, grad) + return _Underdetermined(op, grad) else: # We have to defer determining the shape to runtime and use # conditional execution of the appropriate graph. matrix_shape = array_ops.shape(op.inputs[0])[-2:] return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1], - lambda: _overdetermined(op, grad), - lambda: _underdetermined(op, grad)) + lambda: _Overdetermined(op, grad), + lambda: _Underdetermined(op, grad)) @ops.RegisterGradient("MatrixTriangularSolve") From c91dadb3737395de6b09f4f52596d7ce202eff8f Mon Sep 17 00:00:00 2001 From: Max Galkin Date: Fri, 20 Oct 2017 10:43:56 -0700 Subject: [PATCH 26/41] Minor change: extra logging to help understand the effects of OptimizeGraph and PruneGraph calls. PiperOrigin-RevId: 172902338 --- tensorflow/core/grappler/grappler_item_builder.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 54d60cd7aa4..3f6183b6f1e 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -450,12 +450,16 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( } // Optimize the graph (function inlining, l1 optimizations, etc). + VLOG(1) << "Number of nodes in graph before OptimizeGraph: " + << new_item->graph.node_size(); Status optimize_status = OptimizeGraph(new_item->graph, &new_item->graph, cfg); if (!optimize_status.ok()) { LOG(ERROR) << "Graph preprocessing failed: " << optimize_status; return nullptr; } + VLOG(1) << "Number of nodes in graph after OptimizeGraph: " + << new_item->graph.node_size(); if (cfg.prune_graph) { VLOG(1) << "Pruning graph..."; @@ -464,7 +468,8 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( LOG(ERROR) << "Pruning failed: " << status.error_message(); return nullptr; } - VLOG(1) << "Pruning ran succesfully."; + VLOG(1) << "Number of nodes in graph after pruning: " + << new_item->graph.node_size(); } // Validate feed, fetch and init nodes From 8f7439888c7c3ea7f188df64952cfb4f1e082ecc Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Fri, 20 Oct 2017 10:45:35 -0700 Subject: [PATCH 27/41] Patch dynamic_rnn to work in Eager mode PiperOrigin-RevId: 172902635 --- tensorflow/contrib/rnn/BUILD | 2 + .../rnn/python/kernel_tests/core_rnn_test.py | 342 +++++++++++------- tensorflow/python/BUILD | 1 + tensorflow/python/kernel_tests/BUILD | 2 + tensorflow/python/kernel_tests/rnn_test.py | 101 ++++-- tensorflow/python/ops/rnn.py | 76 ++-- 6 files changed, 333 insertions(+), 191 deletions(-) diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 571d299ad9e..29ba26d75dc 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -156,6 +156,7 @@ cuda_py_tests( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", @@ -165,6 +166,7 @@ cuda_py_tests( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/eager:context", ], shard_count = 10, ) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index 2fa033632ac..12def6dcc8a 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -25,10 +25,12 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import rnn as rnn_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl @@ -881,6 +883,7 @@ class LSTMTest(test.TestCase): # Smoke test, this should not raise an error rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) + @test_util.run_in_graph_and_eager_modes() def testDynamicRNNWithTupleStates(self): num_units = 3 input_size = 5 @@ -888,13 +891,20 @@ class LSTMTest(test.TestCase): num_proj = 4 max_length = 8 sequence_length = [4, 6] + in_graph_mode = context.in_graph_mode() with self.test_session(graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) - ] + if in_graph_mode: + inputs = max_length * [ + array_ops.placeholder( + dtypes.float32, shape=(None, input_size)) + ] + else: + inputs = max_length * [ + constant_op.constant( + np.random.randn(batch_size, input_size).astype(np.float32)) + ] inputs_c = array_ops.stack(inputs) cell = rnn_cell.LSTMCell( num_units, @@ -924,21 +934,34 @@ class LSTMTest(test.TestCase): self.assertEqual(state_dynamic[0], state_dynamic.c) self.assertEqual(state_dynamic[1], state_dynamic.h) - variables_lib.global_variables_initializer().run() + if in_graph_mode: + variables_lib.global_variables_initializer().run() + input_value = np.random.randn(batch_size, input_size) + outputs_static = sess.run( + outputs_static, feed_dict={ + inputs[0]: input_value + }) + outputs_dynamic = sess.run( + outputs_dynamic, feed_dict={ + inputs[0]: input_value + }) + state_static = sess.run( + state_static, feed_dict={ + inputs[0]: input_value + }) + state_dynamic = sess.run( + state_dynamic, feed_dict={ + inputs[0]: input_value + }) - input_value = np.random.randn(batch_size, input_size) - outputs_static_v = sess.run(outputs_static, - feed_dict={inputs[0]: input_value}) - outputs_dynamic_v = sess.run(outputs_dynamic, - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(outputs_static_v, outputs_dynamic_v) - - state_static_v = sess.run(state_static, - feed_dict={inputs[0]: input_value}) - state_dynamic_v = sess.run(state_dynamic, - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) + if in_graph_mode: + self.assertAllEqual(outputs_static, outputs_dynamic) + else: + self.assertAllEqual( + array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy()) + self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) + @test_util.run_in_graph_and_eager_modes() def testDynamicRNNWithNestedTupleStates(self): num_units = 3 input_size = 5 @@ -946,13 +969,20 @@ class LSTMTest(test.TestCase): num_proj = 4 max_length = 8 sequence_length = [4, 6] + in_graph_mode = context.in_graph_mode() with self.test_session(graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) - ] + if in_graph_mode: + inputs = max_length * [ + array_ops.placeholder( + dtypes.float32, shape=(None, input_size)) + ] + else: + inputs = max_length * [ + constant_op.constant( + np.random.randn(batch_size, input_size).astype(np.float32)) + ] inputs_c = array_ops.stack(inputs) def _cell(i): @@ -993,20 +1023,34 @@ class LSTMTest(test.TestCase): sequence_length=sequence_length, scope=scope) - variables_lib.global_variables_initializer().run() + if in_graph_mode: + input_value = np.random.randn(batch_size, input_size) + variables_lib.global_variables_initializer().run() + outputs_static = sess.run( + outputs_static, feed_dict={ + inputs[0]: input_value + }) + outputs_dynamic = sess.run( + outputs_dynamic, feed_dict={ + inputs[0]: input_value + }) + state_static = sess.run( + nest.flatten(state_static), feed_dict={ + inputs[0]: input_value + }) + state_dynamic = sess.run( + nest.flatten(state_dynamic), feed_dict={ + inputs[0]: input_value + }) - input_value = np.random.randn(batch_size, input_size) - outputs_static_v = sess.run(outputs_static, - feed_dict={inputs[0]: input_value}) - outputs_dynamic_v = sess.run(outputs_dynamic, - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(outputs_static_v, outputs_dynamic_v) - - state_static_v = sess.run(nest.flatten(state_static), - feed_dict={inputs[0]: input_value}) - state_dynamic_v = sess.run(nest.flatten(state_dynamic), - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) + if in_graph_mode: + self.assertAllEqual(outputs_static, outputs_dynamic) + else: + self.assertAllEqual( + array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy()) + state_static = [s.numpy() for s in nest.flatten(state_static)] + state_dynamic = [s.numpy() for s in nest.flatten(state_dynamic)] + self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length): time_steps = 8 @@ -1015,21 +1059,22 @@ class LSTMTest(test.TestCase): input_size = 5 batch_size = 2 - input_values = np.random.randn(time_steps, batch_size, input_size) + input_values = np.random.randn(time_steps, batch_size, input_size).astype( + np.float32) if use_sequence_length: sequence_length = np.random.randint(0, time_steps, size=batch_size) else: sequence_length = None - ########### Step 1: Run static graph and generate readouts - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: - concat_inputs = array_ops.placeholder( - dtypes.float32, shape=(time_steps, batch_size, input_size)) - inputs = array_ops.unstack(concat_inputs) + in_graph_mode = context.in_graph_mode() + + # TODO(b/68017812): Eager ignores operation seeds, so we need to create a + # single cell and reuse it across the static and dynamic RNNs. Remove this + # special case once is fixed. + if not in_graph_mode: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - cell = rnn_cell.LSTMCell( num_units, use_peepholes=True, @@ -1037,63 +1082,85 @@ class LSTMTest(test.TestCase): num_proj=num_proj, state_is_tuple=False) + ########### Step 1: Run static graph and generate readouts + with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + if in_graph_mode: + concat_inputs = array_ops.placeholder( + dtypes.float32, shape=(time_steps, batch_size, input_size)) + else: + concat_inputs = constant_op.constant(input_values) + inputs = array_ops.unstack(concat_inputs) + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, seed=self._seed) + + # TODO(akshayka): Remove special case once b/68017812 is fixed. + if in_graph_mode: + cell = rnn_cell.LSTMCell( + num_units, + use_peepholes=True, + initializer=initializer, + num_proj=num_proj, + state_is_tuple=False) + with variable_scope.variable_scope("dynamic_scope"): outputs_static, state_static = rnn.static_rnn( cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32) - feeds = {concat_inputs: input_values} + if in_graph_mode: + # Generate gradients and run sessions to obtain outputs + feeds = {concat_inputs: input_values} + # Initialize + variables_lib.global_variables_initializer().run(feed_dict=feeds) + # Generate gradients of sum of outputs w.r.t. inputs + static_gradients = gradients_impl.gradients( + outputs_static + [state_static], [concat_inputs]) + # Generate gradients of individual outputs w.r.t. inputs + static_individual_gradients = nest.flatten([ + gradients_impl.gradients(y, [concat_inputs]) + for y in [outputs_static[0], outputs_static[-1], state_static] + ]) + # Generate gradients of individual variables w.r.t. inputs + trainable_variables = ops_lib.get_collection( + ops_lib.GraphKeys.TRAINABLE_VARIABLES) + assert len(trainable_variables) > 1, ( + "Count of trainable variables: %d" % len(trainable_variables)) + # pylint: disable=bad-builtin + static_individual_variable_gradients = nest.flatten([ + gradients_impl.gradients(y, trainable_variables) + for y in [outputs_static[0], outputs_static[-1], state_static] + ]) + # Test forward pass + values_static = sess.run(outputs_static, feed_dict=feeds) + (state_value_static,) = sess.run((state_static,), feed_dict=feeds) - # Initialize - variables_lib.global_variables_initializer().run(feed_dict=feeds) + # Test gradients to inputs and variables w.r.t. outputs & final state + static_grad_values = sess.run(static_gradients, feed_dict=feeds) - # Generate gradients of sum of outputs w.r.t. inputs - static_gradients = gradients_impl.gradients( - outputs_static + [state_static], [concat_inputs]) + static_individual_grad_values = sess.run(static_individual_gradients, + feed_dict=feeds) - # Generate gradients of individual outputs w.r.t. inputs - static_individual_gradients = nest.flatten([ - gradients_impl.gradients(y, [concat_inputs]) - for y in [outputs_static[0], outputs_static[-1], state_static] - ]) - - # Generate gradients of individual variables w.r.t. inputs - trainable_variables = ops_lib.get_collection( - ops_lib.GraphKeys.TRAINABLE_VARIABLES) - assert len(trainable_variables) > 1, ("Count of trainable variables: %d" % - len(trainable_variables)) - # pylint: disable=bad-builtin - static_individual_variable_gradients = nest.flatten([ - gradients_impl.gradients(y, trainable_variables) - for y in [outputs_static[0], outputs_static[-1], state_static] - ]) - - # Test forward pass - values_static = sess.run(outputs_static, feed_dict=feeds) - (state_value_static,) = sess.run((state_static,), feed_dict=feeds) - - # Test gradients to inputs and variables w.r.t. outputs & final state - static_grad_values = sess.run(static_gradients, feed_dict=feeds) - - static_individual_grad_values = sess.run(static_individual_gradients, - feed_dict=feeds) - - static_individual_var_grad_values = sess.run( - static_individual_variable_gradients, feed_dict=feeds) + static_individual_var_grad_values = sess.run( + static_individual_variable_gradients, feed_dict=feeds) ########## Step 2: Run dynamic graph and generate readouts with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: - concat_inputs = array_ops.placeholder( - dtypes.float32, shape=(time_steps, batch_size, input_size)) - inputs = array_ops.unstack(concat_inputs) + if in_graph_mode: + concat_inputs = array_ops.placeholder( + dtypes.float32, shape=(time_steps, batch_size, input_size)) + else: + concat_inputs = constant_op.constant(input_values) initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - cell = rnn_cell.LSTMCell( - num_units, - use_peepholes=True, - initializer=initializer, - num_proj=num_proj, - state_is_tuple=False) + # TODO(akshayka): Remove this special case once b/68017812 is + # fixed. + if in_graph_mode: + cell = rnn_cell.LSTMCell( + num_units, + use_peepholes=True, + initializer=initializer, + num_proj=num_proj, + state_is_tuple=False) with variable_scope.variable_scope("dynamic_scope"): outputs_dynamic, state_dynamic = rnn.dynamic_rnn( @@ -1104,72 +1171,83 @@ class LSTMTest(test.TestCase): dtype=dtypes.float32) split_outputs_dynamic = array_ops.unstack(outputs_dynamic, time_steps) - feeds = {concat_inputs: input_values} + if in_graph_mode: + feeds = {concat_inputs: input_values} - # Initialize - variables_lib.global_variables_initializer().run(feed_dict=feeds) + # Initialize + variables_lib.global_variables_initializer().run(feed_dict=feeds) - # Generate gradients of sum of outputs w.r.t. inputs - dynamic_gradients = gradients_impl.gradients( - split_outputs_dynamic + [state_dynamic], [concat_inputs]) + # Generate gradients of sum of outputs w.r.t. inputs + dynamic_gradients = gradients_impl.gradients( + split_outputs_dynamic + [state_dynamic], [concat_inputs]) - # Generate gradients of several individual outputs w.r.t. inputs - dynamic_individual_gradients = nest.flatten([ - gradients_impl.gradients(y, [concat_inputs]) - for y in - [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] - ]) + # Generate gradients of several individual outputs w.r.t. inputs + dynamic_individual_gradients = nest.flatten([ + gradients_impl.gradients(y, [concat_inputs]) + for y in + [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] + ]) - # Generate gradients of individual variables w.r.t. inputs - trainable_variables = ops_lib.get_collection( - ops_lib.GraphKeys.TRAINABLE_VARIABLES) - assert len(trainable_variables) > 1, ("Count of trainable variables: %d" % - len(trainable_variables)) - dynamic_individual_variable_gradients = nest.flatten([ - gradients_impl.gradients(y, trainable_variables) - for y in - [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] - ]) + # Generate gradients of individual variables w.r.t. inputs + trainable_variables = ops_lib.get_collection( + ops_lib.GraphKeys.TRAINABLE_VARIABLES) + assert len(trainable_variables) > 1, ( + "Count of trainable variables: %d" % len(trainable_variables)) + dynamic_individual_variable_gradients = nest.flatten([ + gradients_impl.gradients(y, trainable_variables) + for y in + [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] + ]) - # Test forward pass - values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds) - (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds) + # Test forward pass + values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds) + (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds) - # Test gradients to inputs and variables w.r.t. outputs & final state - dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds) + # Test gradients to inputs and variables w.r.t. outputs & final state + dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds) - dynamic_individual_grad_values = sess.run(dynamic_individual_gradients, - feed_dict=feeds) + dynamic_individual_grad_values = sess.run(dynamic_individual_gradients, + feed_dict=feeds) - dynamic_individual_var_grad_values = sess.run( - dynamic_individual_variable_gradients, feed_dict=feeds) + dynamic_individual_var_grad_values = sess.run( + dynamic_individual_variable_gradients, feed_dict=feeds) ######### Step 3: Comparisons + if not in_graph_mode: + values_static = outputs_static + values_dynamic = split_outputs_dynamic + state_value_static = state_static + state_value_dynamic = state_dynamic + self.assertEqual(len(values_static), len(values_dynamic)) for (value_static, value_dynamic) in zip(values_static, values_dynamic): self.assertAllEqual(value_static, value_dynamic) self.assertAllEqual(state_value_static, state_value_dynamic) - self.assertAllEqual(static_grad_values, dynamic_grad_values) + if in_graph_mode: - self.assertEqual( - len(static_individual_grad_values), len(dynamic_individual_grad_values)) - self.assertEqual( - len(static_individual_var_grad_values), - len(dynamic_individual_var_grad_values)) + self.assertAllEqual(static_grad_values, dynamic_grad_values) - for i, (a, b) in enumerate( - zip(static_individual_grad_values, dynamic_individual_grad_values)): - tf_logging.info("Comparing individual gradients iteration %d" % i) - self.assertAllEqual(a, b) + self.assertEqual( + len(static_individual_grad_values), + len(dynamic_individual_grad_values)) + self.assertEqual( + len(static_individual_var_grad_values), + len(dynamic_individual_var_grad_values)) - for i, (a, b) in enumerate( - zip(static_individual_var_grad_values, - dynamic_individual_var_grad_values)): - tf_logging.info("Comparing individual variable gradients iteration %d" % - i) - self.assertAllEqual(a, b) + for i, (a, b) in enumerate( + zip(static_individual_grad_values, dynamic_individual_grad_values)): + tf_logging.info("Comparing individual gradients iteration %d" % i) + self.assertAllEqual(a, b) + for i, (a, b) in enumerate( + zip(static_individual_var_grad_values, + dynamic_individual_var_grad_values)): + tf_logging.info("Comparing individual variable gradients iteration %d" % + i) + self.assertAllEqual(a, b) + + @test_util.run_in_graph_and_eager_modes() def testDynamicEquivalentToStaticRNN(self): self._testDynamicEquivalentToStaticRNN( use_gpu=False, use_sequence_length=False) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index e63c554e472..b7aa7bbf6b5 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1978,6 +1978,7 @@ py_library( ":tensor_array_ops", ":util", ":variable_scope", + "//tensorflow/python/eager:context", ], ) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index dece290f837..e6848edc128 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2297,6 +2297,7 @@ cuda_py_test( "//tensorflow/python:control_flow_ops", "//tensorflow/python:data_flow_grad", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:nn_grad", @@ -2305,6 +2306,7 @@ cuda_py_test( "//tensorflow/python:sparse_grad", "//tensorflow/python:tensor_array_grad", "//tensorflow/python:variables", + "//tensorflow/python/eager:context", ], shard_count = 10, tags = ["no_windows"], diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index a644e6a44fa..d8f4b439e37 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -26,9 +26,12 @@ import numpy as np from tensorflow.contrib import rnn as contrib_rnn from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl @@ -82,9 +85,13 @@ class RNNTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) + @test_util.run_in_graph_and_eager_modes() def testInvalidSequenceLengthShape(self): cell = Plus1RNNCell() - inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))] + if context.in_graph_mode(): + inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))] + else: + inputs = [constant_op.constant(np.ones((3, 4)))] with self.assertRaisesRegexp(ValueError, "must be a vector"): rnn.dynamic_rnn( cell, @@ -92,45 +99,77 @@ class RNNTest(test.TestCase): dtype=dtypes.float32, sequence_length=[[4]]) + @test_util.run_in_graph_and_eager_modes() def testBatchSizeFromInput(self): cell = Plus1RNNCell() + in_graph_mode = context.in_graph_mode() # With static batch size - inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5)) - # - Without initial_state - outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) - self.assertEqual(3, outputs.shape[0].value) - self.assertEqual(3, state.shape[0].value) - # - With initial_state - outputs, state = rnn.dynamic_rnn( - cell, - inputs, - initial_state=array_ops.placeholder(dtypes.float32, shape=(3, 5))) - self.assertEqual(3, outputs.shape[0].value) - self.assertEqual(3, state.shape[0].value) - # Without static batch size - inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5)) - # - Without initial_state - outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) - self.assertEqual(None, outputs.shape[0].value) - self.assertEqual(None, state.shape[0].value) - # - With initial_state - outputs, state = rnn.dynamic_rnn( - cell, - inputs, - initial_state=array_ops.placeholder(dtypes.float32, shape=(None, 5))) - self.assertEqual(None, outputs.shape[0].value) - self.assertEqual(None, state.shape[0].value) + if in_graph_mode: + inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5)) + initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5)) + else: + inputs = np.zeros((3, 4, 5), dtype=np.float32) + initial_state = np.zeros((3, 5), dtype=np.float32) + # - Without initial_state + outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) + if in_graph_mode: + self.assertEqual(3, outputs.shape[0].value) + self.assertEqual(3, state.shape[0].value) + else: + self.assertEqual(3, outputs.shape[0]) + self.assertEqual(3, state.shape[0]) + + # - With initial_state + outputs, state = rnn.dynamic_rnn( + cell, inputs, initial_state=initial_state) + if in_graph_mode: + self.assertEqual(3, outputs.shape[0].value) + self.assertEqual(3, state.shape[0].value) + else: + self.assertEqual(3, outputs.shape[0]) + self.assertEqual(3, state.shape[0]) + + # Without static batch size + # Tensor shapes are fully determined in Eager mode, so only run this + # test in graph mode. + if in_graph_mode: + inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5)) + # - Without initial_state + outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) + self.assertEqual(None, outputs.shape[0].value) + self.assertEqual(None, state.shape[0].value) + # - With initial_state + outputs, state = rnn.dynamic_rnn( + cell, + inputs, + initial_state=array_ops.placeholder(dtypes.float32, shape=(None, 5))) + self.assertEqual(None, outputs.shape[0].value) + self.assertEqual(None, state.shape[0].value) + + @test_util.run_in_graph_and_eager_modes() def testScalarStateIsAccepted(self): cell = ScalarStateRNNCell() - inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) + in_graph_mode = context.in_graph_mode() + + if in_graph_mode: + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) + else: + inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) + with self.test_session() as sess: outputs, state = rnn.dynamic_rnn( cell, inputs, dtype=dtypes.float32, sequence_length=[4]) - outputs, state = sess.run( - [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]}) - self.assertAllEqual(outputs, [[[1], [2], [3], [4]]]) - self.assertEqual(state, 4) + if in_graph_mode: + outputs, state = sess.run( + [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]}) + + if in_graph_mode: + self.assertAllEqual(outputs, np.array([[[1], [2], [3], [4]]])) + self.assertEqual(state, 4) + else: + self.assertAllEqual(outputs.numpy(), np.array([[[1], [2], [3], [4]]])) + self.assertEqual(state.numpy(), 4) ######### Benchmarking RNN code diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index b174956e604..21c7ed361dc 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -27,6 +27,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -576,8 +577,9 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if varscope.caching_device is None: - varscope.set_caching_device(lambda op: op.device) + if context.in_graph_mode(): + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) batch_size = _best_effort_input_batch_size(flat_input) if initial_state is not None: @@ -595,7 +597,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, ["Expected shape for Tensor %s is " % x.name, packed_shape, " but saw shape: ", x_shape]) - if sequence_length is not None: + if context.in_graph_mode() and sequence_length is not None: # Perform some shape validation with ops.control_dependencies( [_assert_has_shape(sequence_length, [batch_size])]): @@ -718,14 +720,19 @@ def _dynamic_rnn_loop(cell, size=time_steps, tensor_array_name=base_name + name) - output_ta = tuple(_create_ta("output_%d" % i, - _infer_state_dtype(dtype, state)) - for i in range(len(flat_output_size))) - input_ta = tuple(_create_ta("input_%d" % i, flat_input[i].dtype) - for i in range(len(flat_input))) - - input_ta = tuple(ta.unstack(input_) - for ta, input_ in zip(input_ta, flat_input)) + in_graph_mode = context.in_graph_mode() + if in_graph_mode: + output_ta = tuple(_create_ta("output_%d" % i, + _infer_state_dtype(dtype, state)) + for i in range(len(flat_output_size))) + input_ta = tuple(_create_ta("input_%d" % i, flat_input[i].dtype) + for i in range(len(flat_input))) + input_ta = tuple(ta.unstack(input_) + for ta, input_ in zip(input_ta, flat_input)) + else: + output_ta = tuple([0 for _ in range(time_steps.numpy())] + for i in range(len(flat_output_size))) + input_ta = flat_input def _time_step(time, output_ta_t, state): """Take a time step of the dynamic RNN. @@ -739,10 +746,13 @@ def _dynamic_rnn_loop(cell, The tuple (time + 1, output_ta_t with updated flow, new_state). """ - input_t = tuple(ta.read(time) for ta in input_ta) - # Restore some shape information - for input_, shape in zip(input_t, inputs_got_shape): - input_.set_shape(shape[1:]) + if in_graph_mode: + input_t = tuple(ta.read(time) for ta in input_ta) + # Restore some shape information + for input_, shape in zip(input_t, inputs_got_shape): + input_.set_shape(shape[1:]) + else: + input_t = tuple(ta[time.numpy()] for ta in input_ta) input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t) call_cell = lambda: cell(input_t, state) @@ -764,8 +774,12 @@ def _dynamic_rnn_loop(cell, # Pack state if using state tuples output = nest.flatten(output) - output_ta_t = tuple( - ta.write(time, out) for ta, out in zip(output_ta_t, output)) + if in_graph_mode: + output_ta_t = tuple( + ta.write(time, out) for ta, out in zip(output_ta_t, output)) + else: + for ta, out in zip(output_ta_t, output): + ta[time.numpy()] = out return (time + 1, output_ta_t, new_state) @@ -777,16 +791,20 @@ def _dynamic_rnn_loop(cell, swap_memory=swap_memory) # Unpack final output if not using output tuples. - final_outputs = tuple(ta.stack() for ta in output_final_ta) - - # Restore some shape information - for output, output_size in zip(final_outputs, flat_output_size): - shape = _concat( - [const_time_steps, const_batch_size], output_size, static=True) - output.set_shape(shape) + if in_graph_mode: + final_outputs = tuple(ta.stack() for ta in output_final_ta) + # Restore some shape information + for output, output_size in zip(final_outputs, flat_output_size): + shape = _concat( + [const_time_steps, const_batch_size], output_size, static=True) + output.set_shape(shape) + else: + final_outputs = output_final_ta final_outputs = nest.pack_sequence_as( structure=cell.output_size, flat_sequence=final_outputs) + if not in_graph_mode: + final_outputs = array_ops.stack(final_outputs, axis=0) return (final_outputs, final_state) @@ -967,8 +985,9 @@ def raw_rnn(cell, loop_fn, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if varscope.caching_device is None: - varscope.set_caching_device(lambda op: op.device) + if context.in_graph_mode(): + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) time = constant_op.constant(0, dtype=dtypes.int32) (elements_finished, next_input, initial_state, emit_structure, @@ -1166,8 +1185,9 @@ def static_rnn(cell, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if varscope.caching_device is None: - varscope.set_caching_device(lambda op: op.device) + if context.in_graph_mode(): + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) # Obtain the first sequence of the input first_input = inputs From 0f5683d629c6607d1baeaa44ecd264321ae05abc Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Fri, 20 Oct 2017 10:45:51 -0700 Subject: [PATCH 28/41] Migrate the iris example to use TF core API. PiperOrigin-RevId: 172902682 --- tensorflow/examples/learn/iris.py | 99 +++++++++++++++++++++++-------- 1 file changed, 73 insertions(+), 26 deletions(-) diff --git a/tensorflow/examples/learn/iris.py b/tensorflow/examples/learn/iris.py index 33e8d458014..0a50b3ba87d 100644 --- a/tensorflow/examples/learn/iris.py +++ b/tensorflow/examples/learn/iris.py @@ -17,47 +17,94 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from sklearn import datasets -from sklearn import metrics -from sklearn import model_selection +import os +import urllib import tensorflow as tf +# Data sets +IRIS_TRAINING = 'iris_training.csv' +IRIS_TRAINING_URL = 'http://download.tensorflow.org/data/iris_training.csv' -X_FEATURE = 'x' # Name of the input feature. +IRIS_TEST = 'iris_test.csv' +IRIS_TEST_URL = 'http://download.tensorflow.org/data/iris_test.csv' + +FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width'] + + +def maybe_download_iris_data(file_name, download_url): + """Downloads the file and returns the number of data.""" + if not os.path.exists(file_name): + raw = urllib.urlopen(download_url).read() + with open(file_name, 'w') as f: + f.write(raw) + + # The first line is a comma-separated string. The first one is the number of + # total data in the file. + with open(file_name, 'r') as f: + first_line = f.readline() + num_elements = first_line.split(',')[0] + return int(num_elements) + + +def input_fn(file_name, num_data, batch_size, is_training): + """Creates an input_fn required by Estimator train/evaluate.""" + # If the data sets aren't stored locally, download them. + + def _parse_csv(rows_string_tensor): + """Takes the string input tensor and returns tuple of (features, labels).""" + # Last dim is the label. + num_features = len(FEATURE_KEYS) + num_columns = num_features + 1 + columns = tf.decode_csv(rows_string_tensor, + record_defaults=[[]] * num_columns) + features = dict(zip(FEATURE_KEYS, columns[:num_features])) + labels = tf.cast(columns[num_features], tf.int32) + return features, labels + + def _input_fn(): + """The input_fn.""" + dataset = tf.data.TextLineDataset([file_name]) + # Skip the first line (which does not have data). + dataset = dataset.skip(1) + dataset = dataset.map(_parse_csv) + + if is_training: + # For this small dataset, which can fit into memory, to achieve true + # randomness, the shuffle buffer size is set as the total number of + # elements in the dataset. + dataset = dataset.shuffle(num_data) + dataset = dataset.repeat() + + dataset = dataset.batch(batch_size) + iterator = dataset.make_one_shot_iterator() + features, labels = iterator.get_next() + return features, labels + + return _input_fn def main(unused_argv): - # Load dataset. - iris = datasets.load_iris() - x_train, x_test, y_train, y_test = model_selection.train_test_split( - iris.data, iris.target, test_size=0.2, random_state=42) + tf.logging.set_verbosity(tf.logging.INFO) + + num_training_data = maybe_download_iris_data( + IRIS_TRAINING, IRIS_TRAINING_URL) + num_test_data = maybe_download_iris_data(IRIS_TEST, IRIS_TEST_URL) # Build 3 layer DNN with 10, 20, 10 units respectively. feature_columns = [ - tf.feature_column.numeric_column( - X_FEATURE, shape=np.array(x_train).shape[1:])] + tf.feature_column.numeric_column(key, shape=1) for key in FEATURE_KEYS] classifier = tf.estimator.DNNClassifier( feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3) # Train. - train_input_fn = tf.estimator.inputs.numpy_input_fn( - x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True) - classifier.train(input_fn=train_input_fn, steps=200) + train_input_fn = input_fn(IRIS_TRAINING, num_training_data, batch_size=32, + is_training=True) + classifier.train(input_fn=train_input_fn, steps=400) - # Predict. - test_input_fn = tf.estimator.inputs.numpy_input_fn( - x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False) - predictions = classifier.predict(input_fn=test_input_fn) - y_predicted = np.array(list(p['class_ids'] for p in predictions)) - y_predicted = y_predicted.reshape(np.array(y_test).shape) - - # Score with sklearn. - score = metrics.accuracy_score(y_test, y_predicted) - print('Accuracy (sklearn): {0:f}'.format(score)) - - # Score with tensorflow. + # Eval. + test_input_fn = input_fn(IRIS_TEST, num_test_data, batch_size=32, + is_training=False) scores = classifier.evaluate(input_fn=test_input_fn) print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy'])) From ff0530067435fea5c51605c2e7dfd55f6fe8dfe1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 11:06:03 -0700 Subject: [PATCH 29/41] Avoid silent variable sharing with ResourceVariable class. PiperOrigin-RevId: 172905986 --- tensorflow/contrib/eager/python/BUILD | 2 +- tensorflow/contrib/eager/python/saver_test.py | 13 +++----- tensorflow/python/eager/backprop_test.py | 5 +-- tensorflow/python/eager/function_test.py | 4 +-- .../resource_variable_ops_test.py | 32 ++++++++++++++----- .../python/ops/resource_variable_ops.py | 16 ++++++++++ tensorflow/python/training/saver_test.py | 3 +- 7 files changed, 53 insertions(+), 22 deletions(-) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 702136e3e4f..ace17424fef 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -86,7 +86,7 @@ cuda_py_test( "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python/eager:graph_callable", - "//tensorflow/python:platform_test", + "//tensorflow/python/eager:test", "//tensorflow/python:variables", ], ) diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 29af2b531f4..c89554e6dd0 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -22,6 +22,7 @@ import os from tensorflow.contrib.eager.python import saver as _saver from tensorflow.python.eager import context from tensorflow.python.eager import graph_callable +from tensorflow.python.eager import test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import test class SaverTest(test.TestCase): @@ -38,7 +38,7 @@ class SaverTest(test.TestCase): return '/device:GPU:0' if context.num_gpus() else '/device:CPU:0' def testBasics(self): - with context.eager_mode(), ops.device(self._dev()): + with ops.device(self._dev()): v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') def model(): return array_ops.constant(2.0) * v1 @@ -55,7 +55,7 @@ class SaverTest(test.TestCase): self.assertEqual(v1.read_value().numpy(), 1.0) def testRestoreOnCreate(self): - with context.eager_mode(), ops.device(self._dev()): + with ops.device(self._dev()): def model(init_val): v1 = resource_variable_ops.ResourceVariable(init_val, name='v1') return array_ops.constant(1.0) * v1, v1 @@ -71,12 +71,9 @@ class SaverTest(test.TestCase): # Value is from checkpoint, but not from argument. ret, _ = model(2.0) self.assertEqual(ret.numpy(), 1.0) - # Create it a second time won't re-assign the checkpoint value. - v1_2 = resource_variable_ops.ResourceVariable(3.0, name='v1') - self.assertEqual(v1_2.read_value().numpy(), 3.0) def testRestoreNotFound(self): - with context.eager_mode(), ops.device(self._dev()): + with ops.device(self._dev()): def model(v): return array_ops.constant(1.0) * v @@ -92,7 +89,7 @@ class SaverTest(test.TestCase): _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2')) def testSaveRestoreGraphCallable(self): - with context.eager_mode(), ops.device(self._dev()): + with ops.device(self._dev()): @graph_callable.graph_callable( [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) def model(x): diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 7da8eb0c9b5..9ba5913c65e 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -292,7 +292,7 @@ class BackpropTest(test.TestCase): self.assertEqual(grad.numpy(), 6.0) def testGradientTapeVariable(self): - v = resource_variable_ops.ResourceVariable(1.0) + v = resource_variable_ops.ResourceVariable(1.0, name='v') with backprop.GradientTape() as g: y = v * v grad = g.gradient(y, [v])[0] @@ -457,7 +457,8 @@ class BackpropTest(test.TestCase): add_n.append(1) context.context().add_post_execution_callback(callback) - v = resource_variable_ops.ResourceVariable(constant_op.constant(2.0)) + v = resource_variable_ops.ResourceVariable(constant_op.constant(2.0), + name='v') def fn(): outputs = [] for _ in range(20): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index a4c351e8c9f..33bedb59f3a 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -57,7 +57,7 @@ class FunctionTest(test.TestCase): self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) def testGraphModeWithGradients(self): - v = resource_variable_ops.ResourceVariable(1.0) + v = resource_variable_ops.ResourceVariable(1.0, name='v') @function.defun def step(): @@ -156,7 +156,7 @@ class FunctionTest(test.TestCase): g(constant_op.constant(1.0)) def testGradientTensorConversionWithDefun(self): - three = resource_variable_ops.ResourceVariable(3.0) + three = resource_variable_ops.ResourceVariable(3.0, name='v') @function.defun def f(x): diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index cf4b61674fc..10f9a72c7bb 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -181,7 +181,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes() def testInitFnDtype(self): v = resource_variable_ops.ResourceVariable( - initial_value=lambda: 1, dtype=dtypes.float32) + initial_value=lambda: 1, dtype=dtypes.float32, name="var0") self.assertEqual(dtypes.float32, v.value().dtype) @test_util.run_in_graph_and_eager_modes() @@ -192,26 +192,27 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes() def testInitializeAllVariables(self): - v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32) + v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32, + name="var0") self.evaluate(variables.global_variables_initializer()) self.assertEqual(1.0, self.evaluate(v.value())) @test_util.run_in_graph_and_eager_modes() def testOperatorOverload(self): - v = resource_variable_ops.ResourceVariable(1.0) + v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.assertEqual(2.0, self.evaluate(v + v)) @test_util.run_in_graph_and_eager_modes() def testAssignMethod(self): - v = resource_variable_ops.ResourceVariable(1.0) + v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.evaluate(v.assign(2.0)) self.assertEqual(2.0, self.evaluate(v.value())) @test_util.run_in_graph_and_eager_modes() def testLoad(self): - v = resource_variable_ops.ResourceVariable(1.0) + v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) v.load(2.0) self.assertEqual(2.0, self.evaluate(v.value())) @@ -237,21 +238,21 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes() def testAssignAddMethod(self): - v = resource_variable_ops.ResourceVariable(1.0) + v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.evaluate(v.assign_add(1.0)) self.assertEqual(2.0, self.evaluate(v.value())) @test_util.run_in_graph_and_eager_modes() def testAssignSubMethod(self): - v = resource_variable_ops.ResourceVariable(3.0) + v = resource_variable_ops.ResourceVariable(3.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.evaluate(v.assign_sub(1.0)) self.assertEqual(2.0, self.evaluate(v.value())) @test_util.run_in_graph_and_eager_modes() def testDestroyResource(self): - v = resource_variable_ops.ResourceVariable(3.0) + v = resource_variable_ops.ResourceVariable(3.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.assertEqual(3.0, self.evaluate(v.value())) self.evaluate(resource_variable_ops.destroy_resource_op(v.handle)) @@ -443,6 +444,21 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.destroy_resource_op(var._handle, ignore_lookup_error=False) + def testSharingViaResourceVariableObject(self): + with context.eager_mode(): + _ = resource_variable_ops.ResourceVariable(1.0, name="var0") + with self.assertRaisesRegexp(ValueError, + "'var0' already created"): + _ = resource_variable_ops.ResourceVariable(2.0, name="var0") + with ops.Graph().as_default(): + _ = resource_variable_ops.ResourceVariable(2.0, name="var0") + + def testVariableNameMissing(self): + with context.eager_mode(): + with self.assertRaisesRegexp(ValueError, + "Variables need to have explicit names"): + _ = resource_variable_ops.ResourceVariable(1.0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index aa45752a9d4..c94ddb06275 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -49,6 +49,16 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): container=container) if graph_mode: return handle + + # We do not want two distinct ResourceVariable objects for the same + # underlying resource in the runtime. + # When in eager mode, explicitly ensure so here. When in graph mode, it's + # ensured by always generating different variable names. + exists = gen_resource_variable_ops.var_is_initialized_op(handle) + if exists: + raise ValueError("variable object with name '%s' already created. Use " + "get_variable() if reuse is desired." % + shared_name) with context.graph_mode(), ops.Graph().as_default(): h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, shared_name=shared_name, @@ -273,6 +283,12 @@ class ResourceVariable(variables.Variable): # Save the graph's container prefix for error checking. Reading the value of # the ResourceVariable from another Graph in Eager mode is an error. self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access + if not self._in_graph_mode and not name: + # TODO(ashankar,josh11b): make this unnecessary using the same + # logic as in layer + raise ValueError("Variables need to have explicit names when eager " + "execution is enabled") + with ops.control_dependencies(None): with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name: diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index aeb8eaffe87..4abff1d106a 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -233,7 +233,8 @@ class SaverTest(test.TestCase): def testResourceSaveRestoreCachingDevice(self): save_path = os.path.join(self.get_temp_dir(), "resource_cache") with self.test_session(graph=ops_lib.Graph()) as sess: - v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0") + v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0", + name="v") if context.in_graph_mode(): self.evaluate(variables.global_variables_initializer()) else: From 017a5021a7fdc713357fceecf31068ae5090afaf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 11:13:03 -0700 Subject: [PATCH 30/41] [XLA:CPU] Do not assign parallel tasks to instructions which forward pointers (GetTupleElement and Bitcast), because the process of outlining the instruction into a parallel computation forces the pointed-to buffer to be materialized. PiperOrigin-RevId: 172907063 --- tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 5afb2e67fff..c2213c8f2ef 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -136,6 +136,8 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( instruction->opcode() == HloOpcode::kCall || instruction->opcode() == HloOpcode::kCustomCall || instruction->opcode() == HloOpcode::kSelectAndScatter || + instruction->opcode() == HloOpcode::kGetTupleElement || + instruction->opcode() == HloOpcode::kBitcast || (instruction->opcode() == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction)) || PotentiallyImplementedAsEigenDot(*instruction) || From 86908c30c4c0adf92fa14ed6f1d92616177c1b89 Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Fri, 20 Oct 2017 11:13:45 -0700 Subject: [PATCH 31/41] Step 1: Large refactoring toward wrapping input_fn and TPU infeed into tf.while_loop PiperOrigin-RevId: 172907182 --- .../contrib/tpu/python/tpu/tpu_estimator.py | 1287 +++++++++-------- 1 file changed, 651 insertions(+), 636 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 04e0719a1be..805de16468a 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import collections +from contextlib import contextmanager import copy import threading import six @@ -38,6 +39,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -57,12 +59,15 @@ from tensorflow.python.training import training_util _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. -_DEFAULT_NAME_SCOPE = 'tpu_estimator' +_TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] +# TODO(b/65703635): Flip the value and remove all dead code. +_WRAP_INPUT_FN_INTO_WHILE_LOOP = False + def _create_global_step(graph): graph = graph or ops.get_default_graph() @@ -81,17 +86,25 @@ def _create_global_step(graph): ops.GraphKeys.GLOBAL_STEP]) -def _create_iterations_per_loop(): - with variable_scope.variable_scope(_DEFAULT_NAME_SCOPE, - reuse=variable_scope.AUTO_REUSE): - return variable_scope.get_variable( - _ITERATIONS_PER_LOOP_VAR, - initializer=init_ops.zeros_initializer(), - shape=[], - dtype=dtypes.int32, - trainable=False, - collections=[], - use_resource=True) +def _create_or_get_iterations_per_loop(): + graph = ops.get_default_graph() + iter_vars = graph.get_collection(_TPU_ESTIMATOR) + if len(iter_vars) == 1: + return iter_vars[0] + elif len(iter_vars) > 1: + raise RuntimeError('Multiple iterations_per_loop_var in collection.') + + with ops.colocate_with(training_util.get_global_step()): + with variable_scope.variable_scope(_TPU_ESTIMATOR, + reuse=variable_scope.AUTO_REUSE): + return variable_scope.get_variable( + _ITERATIONS_PER_LOOP_VAR, + initializer=init_ops.zeros_initializer(), + shape=[], + dtype=dtypes.int32, + trainable=False, + collections=[_TPU_ESTIMATOR], + use_resource=True) def _sync_variables_ops(): @@ -127,64 +140,209 @@ _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' _LOCAL_MASTERS = ('', 'local') -def _tpu_job(run_config, mode): - """Returns the job name to use to place TPU computations on. +class _TPUContext(object): + """A context holds immutable states of TPU computation. - Args: - run_config: The tpu_config.RunConfig used for this custom estimator. - mode: A model_fn_lib.ModeKeys value. + This immutable object holds TPUEstimator config, train/eval batch size, and + `TPUEstimator.use_tpu`, which is expected to be passed around. It also + provides utility functions, basded on the current state, to determine other + information commonly required by TPU computation, such as TPU device names, + TPU hosts, shard batch size, etc. - Returns: - A string containing the job name, or None if no job should be specified. - - Raises: - ValueError: If the user needs to specify a tpu_job_name, because we are - unable to infer the job name automatically, or if the user-specified job - names are inappropriate. + N.B. As `mode` is not immutable state in Estimator, but essential to + distinguish between TPU training and evaluation, a common usage for + _TPUContext with `mode` is as follows: + ``` + with _ctx.with_mode(mode) as ctx: + if ctx.is_running_on_cpu(): + ... + ``` """ - # If the user specifies the tpu_job_name, use that. - if run_config.tpu_config.tpu_job_name: - return run_config.tpu_config.tpu_job_name - # The tpu job is determined by the run_config. Right now, this method is - # required as tpu_config is not part of the RunConfig. - master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL - else run_config.master) - if master in _LOCAL_MASTERS: - return None + def __init__(self, config, train_batch_size, eval_batch_size, use_tpu): + self._config = config + self._train_batch_size = train_batch_size + self._eval_batch_size = eval_batch_size + self._use_tpu = use_tpu + self._num_shards_or_none = self._config.tpu_config.num_shards + self._mode = None - if (not run_config.session_config or - not run_config.session_config.cluster_def.job): - return _DEFAULT_JOB_NAME - cluster_def = run_config.session_config.cluster_def - job_names = set([job.name for job in cluster_def.job]) - if _DEFAULT_JOB_NAME in job_names: - # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. - raise ValueError('Currently, tpu_worker is not an allowed job name.') - if len(job_names) == 1: - return cluster_def.job[0].name - if len(job_names) == 2: - if _DEFAULT_COORDINATOR_JOB_NAME in job_names: - job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) - return job_names.pop() - # TODO(b/67716447): Include more sophisticated heuristics. - raise ValueError( - 'Could not infer TPU job name. Please specify a tpu_job_name as part of ' - 'your TPUConfig.') + def _assert_mode(self): + if self._mode is None: + raise RuntimeError( + '`mode` needs to be set via contextmanager `with_mode`.') + return self._mode + @property + def num_of_cores_per_host(self): + num_cores = self.num_cores + return min(num_cores, 8) -def _is_running_on_cpu(use_tpu, mode, eval_batch_size): - """Determines whether the input_fn and model_fn should be invoked on CPU.""" - return ((not use_tpu) or mode == model_fn_lib.ModeKeys.PREDICT or - (mode == model_fn_lib.ModeKeys.EVAL and eval_batch_size is None)) + @contextmanager + def with_mode(self, mode): + new_ctx = copy.copy(self) # Shallow copy is enough. + new_ctx._mode = mode # pylint: disable=protected-access + yield new_ctx + @property + def mode(self): + return self._assert_mode() -def _per_shard_batch_size(global_batch_size, run_config, use_tpu): - """Returns the batch size for each shard.""" - if use_tpu: - return global_batch_size // run_config.tpu_config.num_shards - else: - return global_batch_size + @property + def num_cores(self): + # TODO(xiejw): Adds lazy num_shards initialization. + return self._num_shards_or_none + + @property + def num_hosts(self): + return self.num_cores // self.num_of_cores_per_host + + @property + def config(self): + return self._config + + def is_input_sharded_per_core(self): + """Return true if input_fn is invoked per-core (other than per-host).""" + self._assert_mode() + return (self._mode == model_fn_lib.ModeKeys.TRAIN and + not self._config.tpu_config.per_host_input_for_training) + + def is_running_on_cpu(self): + """Determines whether the input_fn and model_fn should be invoked on CPU.""" + mode = self._assert_mode() + return ((not self._use_tpu) or mode == model_fn_lib.ModeKeys.PREDICT or + (mode == model_fn_lib.ModeKeys.EVAL and + self._eval_batch_size is None)) + + @property + def batch_size_for_input_fn(self): + """Returns the shard batch size for `input_fn`.""" + mode = self._assert_mode() + # Special case for eval. + if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: + return None + if self.is_running_on_cpu(): + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size + if mode == model_fn_lib.ModeKeys.EVAL: + return self._eval_batch_size + return None + + global_batch_size = (self._train_batch_size if + mode == model_fn_lib.ModeKeys.TRAIN + else self._eval_batch_size) + # On TPU + return (global_batch_size // self.num_cores + if self.is_input_sharded_per_core() else global_batch_size) + + @property + def batch_size_for_model_fn(self): + """Returns the shard batch size for `model_fn`.""" + mode = self._assert_mode() + # Special case for eval. + if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: + return None + if self.is_running_on_cpu(): + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size + if mode == model_fn_lib.ModeKeys.EVAL: + return self._eval_batch_size + return None + + # On TPU. always sharded per core. + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size // self.num_cores + else: + return self._eval_batch_size // self.num_cores + + @property + def master_job(self): + """Returns the job name to use to place TPU computations on. + + Returns: + A string containing the job name, or None if no job should be specified. + + Raises: + ValueError: If the user needs to specify a tpu_job_name, because we are + unable to infer the job name automatically, or if the user-specified job + names are inappropriate. + """ + run_config = self._config + # If the user specifies the tpu_job_name, use that. + if run_config.tpu_config.tpu_job_name: + return run_config.tpu_config.tpu_job_name + + # The tpu job is determined by the run_config. Right now, this method is + # required as tpu_config is not part of the RunConfig. + mode = self._assert_mode() + master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL + else run_config.master) + if master in _LOCAL_MASTERS: + return None + + if (not run_config.session_config or + not run_config.session_config.cluster_def.job): + return _DEFAULT_JOB_NAME + cluster_def = run_config.session_config.cluster_def + job_names = set([job.name for job in cluster_def.job]) + if _DEFAULT_JOB_NAME in job_names: + # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. + raise ValueError('Currently, tpu_worker is not an allowed job name.') + if len(job_names) == 1: + return cluster_def.job[0].name + if len(job_names) == 2: + if _DEFAULT_COORDINATOR_JOB_NAME in job_names: + job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) + return job_names.pop() + # TODO(b/67716447): Include more sophisticated heuristics. + raise ValueError( + 'Could not infer TPU job name. Please specify a tpu_job_name as part ' + 'of your TPUConfig.') + + @property + def tpu_host_placement_function(self): + """Returns the TPU host place function.""" + master = self.master_job + def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name + assert _sentinal is None + if core_id is not None and host_id is not None: + raise RuntimeError( + 'core_id and host_id can have only one non-None value.') + + if master is None: + return '/replica:0/task:0/device:CPU:0' + else: + # This assumes that if using more than 8 shards, + # the job configuration varies 'task'. + if core_id is not None: + host_id = core_id / 8 + return '/job:%s/task:%d/device:CPU:0' % (master, host_id) + return _placement_function + + @property + def tpu_device_placement_function(self): + master = self.master_job + job_device = '' if master is None else ('/job:%s' % master) + def _placement_function(i): + return '%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8) + return _placement_function + + @property + def tpu_ordinal_function(self): + """Returns the TPU ordinal fn.""" + def _tpu_ordinal_function(index): + """Return the TPU ordinal associated with a shard. + + Required because the enqueue ops are placed on CPU. + + Args: + index: the shard index + + Returns: + The ordinal of the TPU device the shard's infeed should be placed on. + """ + return index % 8 + return _tpu_ordinal_function class _SIGNAL(object): @@ -319,11 +477,16 @@ class _InfeedThreadController(_InfeedOutfeedThreadBaseController): logging.info('Stop Infeed input thread.') return - iterations = signal - for i in range(iterations): - logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + # Enqueue batches for next loop. session.run(enqueue_ops) - count += 1 + else: + iterations = signal + for i in range(iterations): + logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) + session.run(enqueue_ops) + count += 1 + except Exception: # pylint: disable=broad-except logging.error( 'Failed running infeed, closing session.\n' @@ -346,17 +509,16 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): dequeue. """ - def __init__(self, run_config, mode, enqueue_fn, dequeue_ops=None): - self._tpu_job = _tpu_job(run_config, mode) - self._enqueue_fn = enqueue_fn + def __init__(self, ctx, enqueue_ops, dequeue_ops=None): + self._master_job = ctx.master_job + self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops def begin(self): - self._enqueue_ops = self._enqueue_fn() - self._iterations_per_loop_var = _create_iterations_per_loop() - logging.info('TPU job name %s', self._tpu_job) - self._init_op = [tpu.initialize_system(job=self._tpu_job)] - self._finalize_op = [tpu.shutdown_system(job=self._tpu_job)] + logging.info('TPU job name %s', self._master_job) + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() + self._init_op = [tpu.initialize_system(job=self._master_job)] + self._finalize_op = [tpu.shutdown_system(job=self._master_job)] def after_create_session(self, session, coord): logging.info('Init TPU system') @@ -378,6 +540,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): iterations = run_context.session.run(self._iterations_per_loop_var) self._infeed_thd_controller.send_next_batch_signal(iterations) if self._dequeue_ops is not None: + # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop. logging.info('Dequeue next batch of data from outfeed.') self._outfeed_thd_controller.send_next_batch_signal(iterations) @@ -439,7 +602,7 @@ class _TPUStopAtStepHook(session_run_hook.SessionRunHook): if self._global_step_tensor is None: raise RuntimeError('Global step should be created.') - self._iterations_per_loop_var = _create_iterations_per_loop() + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() def after_create_session(self, session, coord): global_step = session.run(self._global_step_tensor) @@ -474,360 +637,288 @@ class _SetEvalIterationsHook(session_run_hook.SessionRunHook): self._num_steps = num_steps def begin(self): - self._iterations_per_loop_var = _create_iterations_per_loop() + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() def after_create_session(self, session, coord): self._iterations_per_loop_var.load(self._num_steps, session=session) -class _PerShardOutput(object): - """Wraps input_fn's outputs into per-shard outputs. +def generate_per_core_enqueue_ops_fn_for_host( + ctx, input_fn, inputs_structure_recorder): + """Generates infeed enqueue ops for per-core input_fn on a single host.""" + infeed_queue_holder = {'instance': None} - Used so that the model_fn can distinguish between sharded input and unsharded - inputs (e.g., for export_savedmodel()). - """ + def enqueue_ops_fn(): + """A fn returns enqueue_ops.""" + num_cores_per_host = ctx.num_of_cores_per_host + per_host_sharded_inputs = [] + for core_ordinal in range(num_cores_per_host): + with ops.name_scope('ordinal_%d' % (core_ordinal)): + inputs = input_fn() + if isinstance(inputs, tuple): + features, labels = inputs + else: + features, labels = inputs, None - def __init__(self, output): - self.output = output + inputs_structure_recorder.validate_and_record_structure( + features, labels) + flattened_inputs = ( + inputs_structure_recorder.flatten_features_and_labels( + features, labels)) + per_host_sharded_inputs.append(flattened_inputs) - def as_list(self): - return self.output + infeed_queue = tpu_feed.InfeedQueue( + number_of_tuple_elements=len(per_host_sharded_inputs[0])) + infeed_queue_holder['instance'] = infeed_queue + infeed_queue.set_configuration_from_sharded_input_tensors( + per_host_sharded_inputs) + + per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( + per_host_sharded_inputs, + tpu_ordinal_function=ctx.tpu_ordinal_function) + return per_host_enqueue_ops + return enqueue_ops_fn, (lambda: infeed_queue_holder['instance']) -class _InputsHolder(object): - """A inputs holder holds the `features` and `labels' for TPU system. +class _InputPipeline(object): + """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. - Model inputs returned by the `input_fn` can have one of the following forms: + `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from + call site. To be precise, based on the configuration in `_TPUContext`, it + invokes `input_fn` for all cores (usually multi-host TPU training) or for one + host (usually for single-host TPU evaluation), and sends all `features` and + `labels` returned by `input_fn` to TPU infeed. For per-core invocation, + `features` and `labels` are piped to infeed directly, one tuple for each + core. For per-host invocation, `features` and `labels` are split at host + (with respect to `batch_axis`) and piped to all cores accordingly. + + In addition, flatten/unflatten are handled by `_InputPipeline` also. Model + inputs returned by the `input_fn` can have one of the following forms: 1. features 2. (features, labels) Internally, form 1 is reformed to `(features, None)` as features and labels are passed separatedly to underlying methods. For TPU training, TPUEstimator - expects multiple `features` and `labels` tuples one for each shard. + may expect multiple `features` and `labels` tuples one for each core. - In addition, TPUEstimator allows various different structures for inputs - (namely `features` and `labels`). `features` can be `Tensor` or dict of - string name to `Tensor`, and `labels` could be `None`, `Tensor`, or dict of - string name to `Tensor`. TPU infeed/outfeed library expects flattened tensor - list. So, `features` and `labels` need to be flattened, before infeed enqueue, - and the structure of them needs to be recorded, in order to restore them after - infeed dequeue. - - `_InputsHolder` could hold the `features` and `labels` tuple for all shards - (usually multi-host TPU training) or for one host (usually for single-host TPU - evaluation), records the structure details (including presence, dict or single - tensor, dict names), validates the structure consistency cross all shards, and - encapsulates the flatten/unflatten logic. + TPUEstimator allows various different structures for inputs (namely `features` + and `labels`). `features` can be `Tensor` or dict of string name to `Tensor`, + and `labels` could be `None`, `Tensor`, or dict of string name to `Tensor`. + TPU infeed/outfeed library expects flattened tensor list. So, `features` and + `labels` need to be flattened, before infeed enqueue, and the structure of + them needs to be recorded, in order to restore them after infeed dequeue. """ - def __init__(self, features=None, labels=None, num_shards=None): + class InputsStructureRecorder(object): + """The recorder to record inputs structure.""" + + def __init__(self): + # Holds the structure of inputs + self._feature_names = [] + self._label_names = [] + self._has_labels = False + + # Internal state. + self._initialized = False + + def has_labels(self): + return self._has_labels + + def validate_and_record_structure(self, features, labels): + """Validates and records the structure of features` and `labels`.""" + def _extract_key_names(tensor_or_dict): + if tensor_or_dict is None: + return [] + return tensor_or_dict.keys() if isinstance(tensor_or_dict, dict) else [] + + # Extract structure. + has_labels = labels is not None + feature_names = _extract_key_names(features) + label_names = _extract_key_names(labels) + + if self._initialized: + # Verify the structure is same. The following should never happen. + assert feature_names == self._feature_names, 'feature keys mismatched' + assert label_names == self._label_names, 'label keys mismatched' + assert has_labels == self._has_labels, 'label presence mismatched' + else: + # Record structure. + self._initialized = True + self._feature_names = feature_names + self._label_names = label_names + self._has_labels = has_labels + + def flatten_features_and_labels(self, features, labels): + """Flattens the `features` and `labels` to a single tensor list.""" + flattened_inputs = [] + if self._feature_names: + # We need a fixed ordering for enqueueing and dequeueing. + flattened_inputs.extend([features[name] + for name in self._feature_names]) + else: + flattened_inputs.append(features) + + if labels is not None: + if self._label_names: + # We need a fixed ordering for enqueueing and dequeueing. + flattened_inputs.extend([labels[name] for name in self._label_names]) + else: + flattened_inputs.append(labels) + return flattened_inputs + + def unflatten_features_and_labels(self, flattened_inputs): + """Restores the flattened inputs to original features and labels form. + + Args: + flattened_inputs: Flattened inputs for each shard. + + Returns: + A tuple of (`features`, `labels`), where `labels` could be None. + Each one, if present, should have identical structure (single tensor vs + dict) as the one returned by input_fn. + + Raises: + ValueError: If the number of expected tensors from `flattened_inputs` + mismatches the recorded structure. + """ + expected_num_features = (len(self._feature_names) if self._feature_names + else 1) + if self._has_labels: + expected_num_labels = (len(self._label_names) if self._label_names + else 1) + else: + expected_num_labels = 0 + + expected_num_tensors = expected_num_features + expected_num_labels + + if expected_num_tensors != len(flattened_inputs): + raise ValueError( + 'The number of flattened tensors mismatches expected num. ' + 'Expected {}, got {}'.format(expected_num_tensors, + len(flattened_inputs))) + if self._feature_names: + unflattened_features = dict( + zip(self._feature_names, flattened_inputs[:expected_num_features])) + else: + # Single tensor case + unflattened_features = flattened_inputs[0] + + if expected_num_labels == 0: + unflattened_label = None + elif self._label_names: + unflattened_label = dict(zip(self._label_names, + flattened_inputs[expected_num_features:])) + else: + # Single tensor case. + unflattened_label = flattened_inputs[expected_num_features] + + return unflattened_features, unflattened_label + + def __init__(self, input_fn, batch_axis, ctx): """Constructor. Args: - features: features for one host or a list of features one for each shard - (must be type `_PerShardOutput`). Once provided, the corresponding - `labels` should be set also and this `_InputsHolder` is frozen to - prevent from future modification. If `None`, it is expected to add - features and labels for each shard by calling `append_tuple` later. - labels: labels for one host or a list of labels one for each shard - (must be type `_PerShardOutput`). - num_shards: Number of shards in the TPU system. Must be provided unless it - can be deduced from `features`. + input_fn: input fn for train or eval. + batch_axis: A python tuple of int values describing how each tensor + produced by the Estimator `input_fn` should be split across the TPU + compute shards. + ctx: A `_TPUContext` instance with mode. Raises: - ValueError: If both `sharded_features` and `num_shards` are `None`. + ValueError: If both `sharded_features` and `num_cores` are `None`. """ - # Holds the features and labels for all shards. - self._feature_list = [] - self._label_list = [] + self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder() - # Holds the structure of inputs - self._feature_names = [] - self._label_names = [] - self._has_labels = False + self._sharded_per_core = ctx.is_input_sharded_per_core() + self._input_fn = input_fn + self._infeed_queue = None + self._ctx = ctx + self._batch_axis = batch_axis - # Internal state. - self._initialized = False - self._frozen = False - self._sharded = False + def generate_infeed_enqueue_ops_and_dequeue_fn(self): + """Generates infeed enqueue ops and dequeue_fn.""" + # While tf.while_loop is called, the body function, which invokes + # `enqueue_fn` passed in, is called to construct the graph. So, input_fn + # structure is recorded. + enqueue_ops = self._invoke_input_fn_and_record_structure() + + def dequeue_fn(): + """dequeue_fn is used by TPU to retrieve the tensors.""" + values = self._infeed_queue.generate_dequeue_op() + # The unflatten process uses the structure information recorded above. + return self._inputs_structure_recorder.unflatten_features_and_labels( + values) + + return (enqueue_ops, dequeue_fn) + + def _invoke_input_fn_and_record_structure(self): + if self._sharded_per_core: + # Per-Core input pipeline deployment. + tpu_host_placement_fn = self._ctx.tpu_host_placement_function + enqueue_ops = [] + infeed_queues = [] + + # Invoke input pipeline for each core and placed on the corresponding + # host. + num_hosts = self._ctx.num_hosts + for host_id in range(num_hosts): + host_device = tpu_host_placement_fn(host_id=host_id) + with ops.device(host_device): + with ops.name_scope('input_pipeline_task%d' % (host_id)): + enqueue_ops_fn, infeed_queue_getter = ( + generate_per_core_enqueue_ops_fn_for_host( + self._ctx, self._input_fn, self._inputs_structure_recorder)) + + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + enqueue_ops.append(_wrap_computation_in_while_loop( + device=host_device, op_fn=enqueue_ops_fn)) + else: + enqueue_ops.append(enqueue_ops_fn()) + # Infeed_queue_getter must be called after enqueue_ops_fn is called. + infeed_queues.append(infeed_queue_getter()) + + # infeed_queue is used to generate dequeue ops. The only thing it uses for + # dequeue is dtypes and types. So, any one can be used. Here, grab the + # first one. + self._infeed_queue = infeed_queues[0] + return enqueue_ops - if features is None: - if num_shards is None: - raise ValueError( - '`features` and `num_shards` cannot be both None') - self._num_shards = num_shards - elif isinstance(features, _PerShardOutput): - self._from_sharded_inputs(features, labels, num_shards) else: - if num_shards is None: - raise ValueError( - '`num_shards` cannot be None for unsharded features.') - self._from_unsharded_inputs(features, labels, num_shards) + # TODO(b/67051042): Extend this to multi-host support. + host_id = 0 + host_device = self._ctx.tpu_host_placement_function(host_id=host_id) + def enqueue_fn(): + with ops.device(host_device): + with ops.name_scope('input_pipeline_task%d' % (host_id)): + inputs = self._input_fn() + if isinstance(inputs, tuple): + features, labels = inputs + else: + features, labels = inputs, None + self._inputs_structure_recorder.validate_and_record_structure( + features, labels) + unsharded_tensor_list = ( + self._inputs_structure_recorder.flatten_features_and_labels( + features, labels)) - def _from_unsharded_inputs(self, features, labels, num_shards): - """Initializes the inputs with unsharded features and labels.""" - self._num_shards = num_shards - if labels is not None: - self._has_labels = True - self.append_tuple((features, labels)) - else: - self.append_tuple(features) + self._infeed_queue = tpu_feed.InfeedQueue( + tuple_types=[t.dtype for t in unsharded_tensor_list], + tuple_shapes=[t.shape for t in unsharded_tensor_list], + shard_dimensions=self._batch_axis) + self._infeed_queue.set_number_of_shards(self._ctx.num_cores) - self._sharded = False - self._frozen = True + def placement_fn(core_id): + return self._ctx.tpu_host_placement_function(core_id=core_id) + return ( + self._infeed_queue.split_inputs_and_generate_enqueue_ops( + unsharded_tensor_list, + placement_function=placement_fn)) - def _from_sharded_inputs(self, sharded_features, sharded_labels, num_shards): - """Initializes the inputs with sharded features and labels.""" - if not isinstance(sharded_features, _PerShardOutput): - raise ValueError('`sharded_features` must have type `_PerShardOutput`.') - features = sharded_features.as_list() - - if num_shards is not None and num_shards != len(features): - raise ValueError( - '`num_shards` should be same as the length of sharded_features.') - - self._num_shards = len(features) - if not self._num_shards: - raise ValueError('`sharded_features` should not be empty.') - - if sharded_labels is not None: - if not isinstance(sharded_labels, _PerShardOutput): - raise ValueError('sharded_labels` must have type `_PerShardOutput`.') - - self._has_labels = True - labels = sharded_labels.as_list() - if self._num_shards != len(labels): - raise ValueError( - 'Length of `sharded_features` and `sharded_labels` mismatch.') - - if self._has_labels: - for (f, l) in zip(features, labels): - self.append_tuple((f, l)) - else: - for f in features: - self.append_tuple(f) - - self._sharded = True - self._frozen = True - - def _extract_key_names(self, tensor_or_dict): - if tensor_or_dict is None: - return [] - - return tensor_or_dict.keys() if isinstance(tensor_or_dict, dict) else [] - - def _validate(self, features, labels): - has_labels = labels is not None - feature_names = self._extract_key_names(features) - label_names = self._extract_key_names(labels) - - if self._initialized: - self._sharded = True - # The following should never happen. - assert feature_names == self._feature_names, 'feature keys mismatched' - assert label_names == self._label_names, 'label keys mismatched' - assert has_labels == self._has_labels, 'label presence mismatched' - else: - self._initialized = True - self._feature_names = feature_names - self._label_names = label_names - self._has_labels = has_labels - - @property - def sharded(self): - if not self._frozen: - raise RuntimeError('_InputsHolder has not been frozen yet.') - return self._sharded - - @property - def num_shards(self): - if not self._frozen: - raise RuntimeError('_InputsHolder has not been frozen yet.') - return self._num_shards - - def append_tuple(self, inputs): - """Appends `inputs` for one shard into holder. - - Args: - inputs: The return from `input_fn`, which could be features or tuple of - (features, labels). After the first `inputs` appended into - `_InputsHolder`, the structure of `features` and `labels is recorded. - Any future invocation should provide the `inputs` with same structure. - - Raises: - RuntimeError: If the internal data has been frozen already. - """ - if self._frozen: - raise RuntimeError('InputsHolder has frozen, which cannot be mutated.') - - # input_fn may return either features or (features, labels) - if isinstance(inputs, tuple): - features, labels = inputs - else: - features, labels = inputs, None - - self._validate(features, labels) - - self._feature_list.append(features) - if labels is not None: - self._label_list.append(labels) - - def as_features_and_labels_tuple(self): - """Returns features and labels as grouped tuple. - - This is intended to be used to pass features and labels for all shards from - input_fn to model_fn as the parent class `Estimator` does not have the - concept of shards. So, grouped tuple is required. - - Once called, the internal data is frozen and `append_tuple` cannot be - invoked anymore. - - Returns: - A tuple of features and labels. Both have type `_PerShardOutput`, holding - the inputs for all shards. `labels` could be `None`. - - Raises: - RuntimeError: If the internal data has not been initialized. - """ - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - - assert len(self._feature_list) == self._num_shards - if not self._label_list or all(l is None for l in self._label_list): - return _PerShardOutput(self._feature_list), None - - assert len(self._label_list) == self._num_shards - return (_PerShardOutput(self._feature_list), - _PerShardOutput(self._label_list)) - - def as_sharded_flattened_inputs(self): - """Flatten the features and label as tensor lists for all shards. - - Flattened tensor list contains all tensors in `features` (dict) and `labels` - (dict). Conceptually, it has the predicated structure like: - - ```python - flatten_list = [] - for name in features: - flatten_list.append(features[name]) - for name in labels: - flatten_list.append(labels[name]) - ``` - - This method handles the label is None case and single tensor case nicely. - - Once called, the internal data is frozen and `append_tuple` cannot be - invokded anymore. - - Returns: - A list of flattened inputs one for each shard. - - Raises: - RuntimeError: If the internal data has not been initialized. - ValueError: If the inputs are sharded. - """ - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - if not self._sharded: - raise ValueError('Inputs are not sharded.') - - sharded_inputs = [] - - for shard in range(self._num_shards): - flattened_inputs = self._as_flattened_inputs( - self._feature_list[shard], - self._label_list[shard] if self._has_labels else None) - sharded_inputs.append(flattened_inputs) - - return sharded_inputs - - def as_flattened_inputs(self): - """Flatten the features and label as a single tensor list for one host.""" - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - if self._sharded: - raise ValueError('Inputs are sharded.') - - return self._as_flattened_inputs( - self._feature_list[0], - self._label_list[0] if self._has_labels else None) - - def _as_flattened_inputs(self, features, labels): - """Flattens the `features` and `labels` to a single tensor list.""" - flattened_inputs = [] - if self._feature_names: - # We need a fixed ordering for enqueueing and dequeueing. - flattened_inputs.extend([features[name] for name in self._feature_names]) - else: - flattened_inputs.append(features) - - if labels is not None: - if self._label_names: - # We need a fixed ordering for enqueueing and dequeueing. - flattened_inputs.extend([labels[name] for name in self._label_names]) + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + return _wrap_computation_in_while_loop(device=host_device, + op_fn=enqueue_fn) else: - flattened_inputs.append(labels) - return flattened_inputs - - def unflatten_features_and_labels(self, flattened_inputs): - """Restores the flattened inputs to original features and labels form. - - Once called, the internal data is frozen and `append_tuple` cannot be - invokded anymore. - - Args: - flattened_inputs: Flattened inputs for one each, which should be created - by the `as_sharded_flattened_inputs` API. - - Returns: - A tuple of (`features`, `labels`), where `labels` could be None. - Each one, if present, should have identical structure (single tensor vs - dict) as the one returned by input_fn. - - Raises: - RuntimeError: If the internal data has not been initialized. - ValueError: If the number of expected tensors from `flattened_inputs` - mismatches the recorded structure. - """ - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - - expected_num_features = (len(self._feature_names) if self._feature_names - else 1) - if self._has_labels: - expected_num_labels = (len(self._label_names) if self._label_names - else 1) - else: - expected_num_labels = 0 - - expected_num_tensors = expected_num_features + expected_num_labels - - if expected_num_tensors != len(flattened_inputs): - raise ValueError( - 'The number of flattened tensors mismatches expected num. ' - 'Expected {}, got {}'.format(expected_num_tensors, - len(flattened_inputs))) - if self._feature_names: - unflattened_features = dict(zip(self._feature_names, - flattened_inputs[:expected_num_features])) - else: - # Single tensor case - unflattened_features = flattened_inputs[0] - - if expected_num_labels == 0: - unflattened_label = None - elif self._label_names: - unflattened_label = dict(zip(self._label_names, - flattened_inputs[expected_num_features:])) - else: - # Single tensor case. - unflattened_label = flattened_inputs[expected_num_features] - - return unflattened_features, unflattened_label + return enqueue_fn() class _ModelFnWrapper(object): @@ -840,20 +931,17 @@ class _ModelFnWrapper(object): train and eval step. """ - def __init__(self, model_fn, config, params, mode, train_batch_size, - eval_batch_size): + def __init__(self, model_fn, config, params, ctx): self._model_fn = model_fn self._config = config self._params = params - self._mode = mode - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size + self._ctx = ctx def call_without_tpu(self, features, labels): # Let CrossShardOptimizer be called without TPU in model_fn, since it's # common to set the train_op even when running evaluate() or predict(). with tpu_function.tpu_shard_context(1): - return self._call_model_fn(features, labels, use_tpu=False) + return self._call_model_fn(features, labels) def convert_to_single_tpu_train_step(self, dequeue_fn): """Converts user provided model_fn` as a single train step on TPU. @@ -883,7 +971,7 @@ class _ModelFnWrapper(object): features, labels = dequeue_fn() estimator_spec = self._verify_estimator_spec( - self._call_model_fn(features, labels, use_tpu=True)) + self._call_model_fn(features, labels)) loss, train_op = estimator_spec.loss, estimator_spec.train_op with ops.control_dependencies([train_op]): return array_ops.identity(loss) @@ -915,13 +1003,13 @@ class _ModelFnWrapper(object): A tuple of eval_fn and eval_metrics. The eval_fn representing the eval step for TPU. and eval_metrics is an `_EvalMetrics` instance. """ - eval_metrics = _EvalMetrics() + eval_metrics = _EvalMetrics(self._ctx) def eval_step(total_loss): """Evaluation step function for use inside a while loop.""" features, labels = dequeue_fn() - tpu_estimator_spec = self._call_model_fn(features, labels, use_tpu=True) + tpu_estimator_spec = self._call_model_fn(features, labels) if not isinstance(tpu_estimator_spec, TPUEstimatorSpec): raise RuntimeError( 'estimator_spec used by TPU evaluation must have type' @@ -935,11 +1023,7 @@ class _ModelFnWrapper(object): return math_ops.add(total_loss, loss) return eval_step, eval_metrics - @property - def config(self): - return self._config - - def _call_model_fn(self, features, labels, use_tpu): + def _call_model_fn(self, features, labels): """Calls the model_fn with required parameters.""" model_fn_args = util.fn_args(self._model_fn) kwargs = {} @@ -950,12 +1034,11 @@ class _ModelFnWrapper(object): if 'labels' in model_fn_args: kwargs['labels'] = labels - else: - if labels is not None: - raise ValueError( - 'model_fn does not take labels, but input_fn returns labels.') + elif labels is not None: + raise ValueError( + 'model_fn does not take labels, but input_fn returns labels.') if 'mode' in model_fn_args: - kwargs['mode'] = self._mode + kwargs['mode'] = self._ctx.mode if 'config' in model_fn_args: kwargs['config'] = config if 'params' in model_fn_args: @@ -966,16 +1049,16 @@ class _ModelFnWrapper(object): 'model_fn ({}) does not include params argument, ' 'required by TPUEstimator to pass batch size as ' 'params[\'batch_size\']'.format(self._model_fn)) - if self._mode == model_fn_lib.ModeKeys.TRAIN: - params[_BATCH_SIZE_KEY] = _per_shard_batch_size( - self._train_batch_size, config, use_tpu) - elif (self._mode == model_fn_lib.ModeKeys.EVAL and - self._eval_batch_size is not None): - params[_BATCH_SIZE_KEY] = _per_shard_batch_size( - self._eval_batch_size, config, use_tpu) + + batch_size_for_model_fn = self._ctx.batch_size_for_model_fn + if batch_size_for_model_fn is not None: + params[_BATCH_SIZE_KEY] = batch_size_for_model_fn estimator_spec = self._model_fn(features=features, **kwargs) - if (not use_tpu) and isinstance(estimator_spec, TPUEstimatorSpec): + if (self._ctx.is_running_on_cpu() and + isinstance(estimator_spec, TPUEstimatorSpec)): + # The estimator_spec will be passed to `Estimator` directly, which expects + # type `EstimatorSpec`. return estimator_spec.as_estimator_spec() else: return estimator_spec @@ -998,7 +1081,8 @@ class _ModelFnWrapper(object): class _EvalMetrics(object): """Class wraps TPUEstimator.eval_metrics.""" - def __init__(self): + def __init__(self, ctx): + self._ctx = ctx self._metric_fn = None self._is_dict = False self._tensor_keys = [] @@ -1081,7 +1165,7 @@ class _EvalMetrics(object): raise RuntimeError('Eval metrics have not been recorded yet') return self._tensors - def to_metric_metric_ops_for_tpu(self, run_config, dummy_update_op): + def to_metric_metric_ops_for_tpu(self, dummy_update_op): """Creates the eval_metric_ops now based on the TPU outfeed. `eval_metric_ops` is defined in `EstimatorSpec`. From all shards, tensors @@ -1090,7 +1174,6 @@ class _EvalMetrics(object): metric fn. Args: - run_config: A `RunConfig` instance. dummy_update_op: A dummy update op. Returns: @@ -1102,9 +1185,7 @@ class _EvalMetrics(object): RuntimeError: If outfeed tensor is scalar. """ - num_shards = run_config.tpu_config.num_shards - job = _tpu_job(run_config, model_fn_lib.ModeKeys.EVAL) - job_device = '' if job is None else ('/job:%s' % job) + num_cores = self._ctx.num_cores # For each i, dequeue_ops[i] is a list containing the tensors from all # shards. This list is concatenated later. @@ -1113,8 +1194,9 @@ class _EvalMetrics(object): dequeue_ops.append([]) # Outfeed ops execute on each JF node. - for i in xrange(num_shards): - with ops.device('%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8)): + tpu_device_placement_fn = self._ctx.tpu_device_placement_function + for i in xrange(num_cores): + with ops.device(tpu_device_placement_fn(i)): outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( dtypes=self._tensor_dtypes, shapes=self._tensor_shapes) for j, item in enumerate(outfeed_tensors): @@ -1122,7 +1204,7 @@ class _EvalMetrics(object): # It is assumed evaluation always happends on single host TPU system. So, # place all ops on tpu host if possible. - with ops.device('{}/device:CPU:0'.format(job_device)): + with ops.device(self._ctx.tpu_host_placement_function(core_id=0)): for i, item in enumerate(dequeue_ops): if dequeue_ops[i][0].shape.ndims == 0: raise RuntimeError( @@ -1167,9 +1249,9 @@ class TPUEstimator(estimator_lib.Estimator): specify `train_batch_size` in constructor, and then get the batch size for each shard in `input_fn` and `model_fn` by `params['batch_size']`. If `TPUConfig.per_host_input_for_training` is `True`, `input_fn` is invoked per - host rather than per shard. In this case, a global batch size is transformed a + host rather than per core. In this case, a global batch size is transformed a per-host batch size in params for `input_fn`, but `model_fn` still gets - per-shard batch size. + per-core batch size. For evaluation, if `eval_batch_size` is None, it is executed on CPU, even if `use_tpu` is `True`. If `eval_batch_size` is not `None`, it is executed on @@ -1327,9 +1409,7 @@ class TPUEstimator(estimator_lib.Estimator): # We cannot store config and params in this constructor as parent # constructor might change them, such as assigning a temp dir for # config.model_dir. - model_function = _augment_model_fn(model_fn, train_batch_size, - eval_batch_size, use_tpu, - batch_axis) + model_function = self._augment_model_fn(model_fn, batch_axis) # Passing non-None params as wrapped model_fn has it. params = params or {} @@ -1338,12 +1418,13 @@ class TPUEstimator(estimator_lib.Estimator): model_dir=model_dir, config=config, params=params) - self._use_tpu = use_tpu - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size self._iterations_per_training_loop = ( self._config.tpu_config.iterations_per_loop) + # All properties passed to _TPUContext are immutable. + self._ctx = _TPUContext(self._config, train_batch_size, eval_batch_size, + use_tpu) + def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -1359,10 +1440,10 @@ class TPUEstimator(estimator_lib.Estimator): return _create_global_step(graph) def _convert_train_steps_to_hooks(self, steps, max_steps): - if _is_running_on_cpu(self._use_tpu, model_fn_lib.ModeKeys.TRAIN, - self._eval_batch_size): - return super(TPUEstimator, self)._convert_train_steps_to_hooks( - steps, max_steps) + with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx: + if ctx.is_running_on_cpu(): + return super(TPUEstimator, self)._convert_train_steps_to_hooks( + steps, max_steps) # On TPU. if steps is None and max_steps is None: @@ -1380,9 +1461,9 @@ class TPUEstimator(estimator_lib.Estimator): steps, max_steps)] def _convert_eval_steps_to_hooks(self, steps): - if _is_running_on_cpu(self._use_tpu, model_fn_lib.ModeKeys.EVAL, - self._eval_batch_size): - return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) + with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx: + if ctx.is_running_on_cpu(): + return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) if steps is None: raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.') @@ -1422,197 +1503,115 @@ class TPUEstimator(estimator_lib.Estimator): if 'config' in input_fn_args: kwargs['config'] = config - # Setting the batch size in params first. This helps user to have same - # input_fn for use_tpu=True/False. - if mode == model_fn_lib.ModeKeys.TRAIN: - kwargs['params'][_BATCH_SIZE_KEY] = ( - _per_shard_batch_size(self._train_batch_size, config, self._use_tpu) - if not config.tpu_config.per_host_input_for_training else - self._train_batch_size) - elif (mode == model_fn_lib.ModeKeys.EVAL and - self._eval_batch_size is not None): - # For TPU evaluation, input_fn is invoked for one host (instead of shard). - kwargs['params'][_BATCH_SIZE_KEY] = self._eval_batch_size + with self._ctx.with_mode(mode) as ctx: + # Setting the batch size in params first. This helps user to have same + # input_fn for use_tpu=True/False. + batch_size_for_input_fn = ctx.batch_size_for_input_fn + if batch_size_for_input_fn is not None: + kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn - if _is_running_on_cpu(self._use_tpu, mode, self._eval_batch_size): - with ops.device('/device:CPU:0'): - return input_fn(**kwargs) - - job = _tpu_job(config, mode) - def placement_function(index): - if job is None: - return '/replica:0/task:0/device:CPU:0' - else: - return '/job:%s/task:%d/device:CPU:0' % (job, index / 8) - - if mode == model_fn_lib.ModeKeys.TRAIN: - if not config.tpu_config.per_host_input_for_training: - # Now for TPU training. - num_shards = config.tpu_config.num_shards - inputs = _InputsHolder(num_shards=num_shards) - for i in range(config.tpu_config.num_shards): - with ops.device(placement_function(i)): - inputs.append_tuple(input_fn(**kwargs)) - return inputs.as_features_and_labels_tuple() - else: - # TODO(xiejw): Extend this to multi-host support. - with ops.device(placement_function(0)): + if ctx.is_running_on_cpu(): + with ops.device('/device:CPU:0'): return input_fn(**kwargs) - # Now for TPU evaluation. - with ops.device(placement_function(0)): - return input_fn(**kwargs) + # For TPU computation, input_fn should be invoked in a tf.while_loop for + # performance. While constructing the tf.while_loop, the structure of + # inputs returned by the `input_fn` needs to be recorded. The structure + # includes whether features or labels is dict or single Tensor, dict keys, + # tensor shapes, and dtypes. The recorded structure is used to create the + # infeed dequeue ops, which must be wrapped and passed as a Fn, called + # inside the TPU computation, as the TPU computation is wrapped inside a + # tf.while_loop also. So, we either pass input_fn to model_fn or pass + # dequeue_fn to model_fn. Here, `input_fn` is passed directly as + # `features` in `model_fn` signature. + def _input_fn(): + return input_fn(**kwargs) + return _input_fn + + def _augment_model_fn(self, model_fn, batch_axis): + """Returns a new model_fn, which wraps the TPU support.""" + + def _model_fn(features, labels, mode, config, params): + """A Estimator `model_fn` for TPUEstimator.""" + with self._ctx.with_mode(mode) as ctx: + model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx) + + # TODO(jhseu): Move to PREDICT to TPU. + if ctx.is_running_on_cpu(): + logging.info('Running %s on CPU', mode) + return model_fn_wrapper.call_without_tpu(features, labels) + + assert labels is None, '`labels` passed to `model_fn` must be `None`.' + # TPUEstimator._call_input_fn passes `input_fn` as features to here. + assert callable(features), '`input_fn` is not callable.' + input_fn = features + + input_holders = _InputPipeline(input_fn, batch_axis, ctx) + enqueue_ops, dequeue_fn = ( + input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) + + if mode == model_fn_lib.ModeKeys.TRAIN: + loss = _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn) + hooks = [ + TPUInfeedOutfeedSessionHook(ctx, enqueue_ops), + training.LoggingTensorHook( + {'loss': array_ops.identity(loss), + 'step': training.get_global_step()}, + every_n_secs=30) + ] + summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) + with ops.control_dependencies([loss]): + update_ops = _sync_variables_ops() + + # Validate the TPU training graph to catch basic errors + _validate_tpu_training_graph() + + return model_fn_lib.EstimatorSpec( + mode, + loss=loss, + training_hooks=hooks, + train_op=control_flow_ops.group(*update_ops)) + + # Now eval. + total_loss, eval_metric_ops = _eval_on_tpu_system( + ctx, model_fn_wrapper, dequeue_fn) + iterations_per_loop_var = _create_or_get_iterations_per_loop() + mean_loss = math_ops.div( + total_loss, + math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype)) + + # Creates a dummy metric update_op for all metrics. Estimator expects + # all metrics in eval_metric_ops have update_op and calls them one by + # one. The real metric update_ops are invoked in a separated thread. So, + # here give Estimator the dummy op for all metrics. + with ops.control_dependencies([mean_loss]): + # After TPU evaluation computation is done (the mean_loss tensor), + # reads all variables back from TPU and updates the eval step counter + # properly + internal_ops_to_run = _sync_variables_ops() + internal_ops_to_run.append( + _increase_eval_step_op(iterations_per_loop_var)) + with ops.control_dependencies(internal_ops_to_run): + dummy_update_op = control_flow_ops.no_op() + + eval_metric_ops, eval_update_ops = ( + eval_metric_ops.to_metric_metric_ops_for_tpu(dummy_update_op)) + hooks = [ + TPUInfeedOutfeedSessionHook(ctx, enqueue_ops, eval_update_ops), + ] + + return model_fn_lib.EstimatorSpec( + mode, + loss=mean_loss, + evaluation_hooks=hooks, + eval_metric_ops=eval_metric_ops) + return _model_fn -# TODO(b/64607814): Ensure batch_axis works with nested structures. -def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder, run_config, - batch_axis, mode): - """Utility to convert input_fn to enqueue and dequeue fns for TPU. - - Args: - inputs_holder: An `_InputsHolder` holding features and labels. - run_config: A `RunConfig` instance. - batch_axis: A python list of batch dimensions. - mode: ModeKeys - - Returns: - A tuple of (dequeue_fn, enqueue_fn) - """ - if inputs_holder.sharded: - sharded_inputs = inputs_holder.as_sharded_flattened_inputs() - - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(sharded_inputs[0])) - infeed_queue.set_configuration_from_sharded_input_tensors(sharded_inputs) - else: - unsharded_inputs = inputs_holder.as_flattened_inputs() - infeed_queue = tpu_feed.InfeedQueue( - tuple_types=[t.dtype for t in unsharded_inputs], - tuple_shapes=[t.shape for t in unsharded_inputs], - shard_dimensions=batch_axis) - infeed_queue.set_number_of_shards(inputs_holder.num_shards) - - def dequeue_fn(): - """dequeue_fn is used by the train_step in TPU to retrieve the tensors.""" - values = infeed_queue.generate_dequeue_op() - return inputs_holder.unflatten_features_and_labels(values) - - def tpu_ordinal_function(index): - """Return the TPU ordinal associated with a shard. - - Required because the enqueue ops are placed on CPU. - - Args: - index: the shard index - - Returns: - The ordinal of the TPU device the shard's infeed should be placed on. - """ - return index % 8 - - def enqueue_fn(): - """enqueue_fn is used to add ops to the graph to send tensors.""" - if inputs_holder.sharded: - return infeed_queue.generate_enqueue_ops( - sharded_inputs, tpu_ordinal_function=tpu_ordinal_function) - else: - job = _tpu_job(run_config, mode) - def placement_function(index): - if job is None: - return '/replica:0/task:0/device:CPU:0' - else: - # This assumes that if using more than 8 shards, - # the job configuration varies 'task'. - return '/job:%s/task:%d/device:CPU:0' % (job, index / 8) - return infeed_queue.split_inputs_and_generate_enqueue_ops( - unsharded_inputs, placement_function=placement_function) - - return (dequeue_fn, enqueue_fn) - - -def _augment_model_fn(model_fn, train_batch_size, eval_batch_size, use_tpu, - batch_axis): - """Returns a new model_fn, which wraps the TPU support.""" - - def _model_fn(features, labels, mode, config, params): - """A Estimator `model_fn` for TPUEstimator.""" - model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, mode, - train_batch_size, eval_batch_size) - - # TODO(jhseu): Move to PREDICT to TPU. - if _is_running_on_cpu(use_tpu, mode, eval_batch_size): - logging.info('Running %s on CPU', mode) - return model_fn_wrapper.call_without_tpu(features, labels) - - inputs = _InputsHolder(features=features, labels=labels, - num_shards=config.tpu_config.num_shards) - - dequeue_fn, enqueue_fn = _create_infeed_enqueue_ops_and_dequeue_fn( - inputs, config, batch_axis, mode) - - if mode == model_fn_lib.ModeKeys.TRAIN: - loss = _train_on_tpu_system(model_fn_wrapper, dequeue_fn) - hooks = [ - TPUInfeedOutfeedSessionHook(config, mode, enqueue_fn), - training.LoggingTensorHook( - {'loss': array_ops.identity(loss), - 'step': training.get_global_step()}, - every_n_secs=30) - ] - summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) - with ops.control_dependencies([loss]): - update_ops = _sync_variables_ops() - - # Validate the TPU training graph to catch basic errors - _validate_tpu_training_graph() - - return model_fn_lib.EstimatorSpec( - mode, - loss=loss, - training_hooks=hooks, - train_op=control_flow_ops.group(*update_ops)) - - # Now eval. - total_loss, eval_metric_ops = _eval_on_tpu_system( - model_fn_wrapper, dequeue_fn) - iterations_per_loop_var = _create_iterations_per_loop() - mean_loss = math_ops.div( - total_loss, - math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype)) - - # Creates a dummy metric update_op for all metrics. Estimator expects all - # metrics in eval_metric_ops have update_op and calls them one by one. The - # real metric update_ops are invoked in a separated thread. So, here give - # Estimator the dummy op for all metrics. - with ops.control_dependencies([mean_loss]): - # After TPU evaluation computation is done (the mean_loss tensor), reads - # all variables back from TPU and updates the eval step counter properly. - internal_ops_to_run = _sync_variables_ops() - internal_ops_to_run.append( - _increase_eval_step_op(iterations_per_loop_var)) - with ops.control_dependencies(internal_ops_to_run): - dummy_update_op = control_flow_ops.no_op() - - eval_metric_ops, eval_update_ops = ( - eval_metric_ops.to_metric_metric_ops_for_tpu( - config, dummy_update_op)) - hooks = [ - TPUInfeedOutfeedSessionHook(config, mode, enqueue_fn, eval_update_ops), - ] - - return model_fn_lib.EstimatorSpec( - mode, - loss=mean_loss, - evaluation_hooks=hooks, - eval_metric_ops=eval_metric_ops) - return _model_fn - - -def _eval_on_tpu_system(model_fn_wrapper, dequeue_fn): +def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - config = model_fn_wrapper.config.tpu_config - num_shards = config.num_shards - iterations_per_loop_var = _create_iterations_per_loop() + num_cores = ctx.num_cores + iterations_per_loop_var = _create_or_get_iterations_per_loop() single_tpu_eval_step, eval_metric_ops = ( model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)) @@ -1625,15 +1624,15 @@ def _eval_on_tpu_system(model_fn_wrapper, dequeue_fn): (loss,) = tpu.shard(multi_tpu_eval_steps_on_single_shard, inputs=[], - num_shards=num_shards, + num_shards=num_cores, outputs_from_all_shards=False) return loss, eval_metric_ops -def _train_on_tpu_system(model_fn_wrapper, dequeue_fn): +def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - num_shards = model_fn_wrapper.config.tpu_config.num_shards - iterations_per_loop_var = _create_iterations_per_loop() + num_cores = ctx.num_cores + iterations_per_loop_var = _create_or_get_iterations_per_loop() single_tpu_train_step = model_fn_wrapper.convert_to_single_tpu_train_step( dequeue_fn) @@ -1647,11 +1646,27 @@ def _train_on_tpu_system(model_fn_wrapper, dequeue_fn): (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard, inputs=[], - num_shards=num_shards, + num_shards=num_cores, outputs_from_all_shards=False) return loss +def _wrap_computation_in_while_loop(device, op_fn): + """Wraps the ops generated by `op_fn` in tf.while_loop.""" + def computation(i): + with ops.control_dependencies(op_fn()): + return i + 1 + + iterations_per_loop_var = _create_or_get_iterations_per_loop() + # By setting parallel_iterations=1, the parallel execution in while_loop is + # basically turned off. + with ops.device(device): + iterations = array_ops.identity(iterations_per_loop_var) + return control_flow_ops.while_loop( + lambda i: i < iterations, + computation, [constant_op.constant(0)], parallel_iterations=1) + + def _validate_tpu_training_graph(): """Validate graph before running distributed training. From 71bdc0efa737e3094033f0c6ea3779b1fc3c8a94 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 11:34:40 -0700 Subject: [PATCH 32/41] Formatting metric_ops. PiperOrigin-RevId: 172910546 --- .../contrib/metrics/python/ops/metric_ops.py | 591 +++++++++++------- 1 file changed, 382 insertions(+), 209 deletions(-) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 85c8e9038ac..09485c4fa2a 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -56,7 +56,10 @@ def _safe_div(numerator, denominator, name): name=name) -def _create_local(name, shape, collections=None, validate_shape=True, +def _create_local(name, + shape, + collections=None, + validate_shape=True, dtype=dtypes.float32): """Creates a new local variable. @@ -87,7 +90,9 @@ def _assert_weights_rank(weights, values): return check_ops.assert_rank_in(weights, (0, array_ops.rank(values))) -def _count_condition(values, weights=None, metrics_collections=None, +def _count_condition(values, + weights=None, + metrics_collections=None, updates_collections=None): """Sums the weights of cases where the given values are True. @@ -134,7 +139,9 @@ def _count_condition(values, weights=None, metrics_collections=None, return value_tensor, update_op -def streaming_true_positives(predictions, labels, weights=None, +def streaming_true_positives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -168,12 +175,17 @@ def streaming_true_positives(predictions, labels, weights=None, tuple. """ return metrics.true_positives( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_true_negatives(predictions, labels, weights=None, +def streaming_true_negatives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -206,20 +218,22 @@ def streaming_true_negatives(predictions, labels, weights=None, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'true_negatives', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'true_negatives', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) - is_true_negative = math_ops.logical_and(math_ops.equal(labels, False), - math_ops.equal(predictions, False)) + is_true_negative = math_ops.logical_and( + math_ops.equal(labels, False), math_ops.equal(predictions, False)) return _count_condition(is_true_negative, weights, metrics_collections, updates_collections) -def streaming_false_positives(predictions, labels, weights=None, +def streaming_false_positives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -253,12 +267,17 @@ def streaming_false_positives(predictions, labels, weights=None, tuple. """ return metrics.false_positives( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_false_negatives(predictions, labels, weights=None, +def streaming_false_negatives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -291,9 +310,12 @@ def streaming_false_negatives(predictions, labels, weights=None, or tuple. """ return metrics.false_negatives( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) # TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. @@ -317,17 +339,18 @@ def _broadcast_weights(weights, values): with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope: weights_shape = weights.get_shape() values_shape = values.get_shape() - if (weights_shape.is_fully_defined() and - values_shape.is_fully_defined() and + if (weights_shape.is_fully_defined() and values_shape.is_fully_defined() and weights_shape.is_compatible_with(values_shape)): return weights with ops.control_dependencies((_assert_weights_rank(weights, values),)): - return math_ops.multiply( - weights, array_ops.ones_like(values), name=scope) + return math_ops.multiply(weights, array_ops.ones_like(values), name=scope) -def streaming_mean(values, weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_mean(values, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the (weighted) mean of the given values. The `streaming_mean` function creates two local variables, `total` and `count` @@ -365,12 +388,18 @@ def streaming_mean(values, weights=None, metrics_collections=None, or tuple. """ return metrics.mean( - values=values, weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + values=values, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) -def streaming_mean_tensor(values, weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_mean_tensor(values, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the element-wise (weighted) mean of the given tensors. In contrast to the `streaming_mean` function which returns a scalar with the @@ -412,12 +441,18 @@ def streaming_mean_tensor(values, weights=None, metrics_collections=None, or tuple. """ return metrics.mean_tensor( - values=values, weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + values=values, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) -def streaming_accuracy(predictions, labels, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_accuracy(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Calculates how often `predictions` matches `labels`. @@ -462,13 +497,19 @@ def streaming_accuracy(predictions, labels, weights=None, tuple. """ return metrics.accuracy( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_precision(predictions, labels, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_precision(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the precision of the predictions with respect to the labels. @@ -512,13 +553,19 @@ def streaming_precision(predictions, labels, weights=None, tuple. """ return metrics.precision( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_recall(predictions, labels, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_recall(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the recall of the predictions with respect to the labels. @@ -560,12 +607,17 @@ def streaming_recall(predictions, labels, weights=None, tuple. """ return metrics.recall( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def _true_negatives(labels, predictions, weights=None, +def _true_negatives(labels, + predictions, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -597,20 +649,22 @@ def _true_negatives(labels, predictions, weights=None, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'true_negatives', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'true_negatives', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) - is_true_negative = math_ops.logical_and(math_ops.equal(labels, False), - math_ops.equal(predictions, False)) + is_true_negative = math_ops.logical_and( + math_ops.equal(labels, False), math_ops.equal(predictions, False)) return _count_condition(is_true_negative, weights, metrics_collections, updates_collections) -def streaming_false_positive_rate(predictions, labels, weights=None, +def streaming_false_positive_rate(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -657,30 +711,35 @@ def streaming_false_positive_rate(predictions, labels, weights=None, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'false_positive_rate', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'false_positive_rate', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) false_p, false_positives_update_op = metrics.false_positives( - labels, predictions, weights, metrics_collections=None, - updates_collections=None, name=None) + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) true_n, true_negatives_update_op = _true_negatives( - labels, predictions, weights, metrics_collections=None, - updates_collections=None, name=None) + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) def compute_fpr(fp, tn, name): return array_ops.where( - math_ops.greater(fp + tn, 0), - math_ops.div(fp, fp + tn), - 0, - name) + math_ops.greater(fp + tn, 0), math_ops.div(fp, fp + tn), 0, name) fpr = compute_fpr(false_p, true_n, 'value') - update_op = compute_fpr( - false_positives_update_op, true_negatives_update_op, 'update_op') + update_op = compute_fpr(false_positives_update_op, true_negatives_update_op, + 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, fpr) @@ -691,7 +750,9 @@ def streaming_false_positive_rate(predictions, labels, weights=None, return fpr, update_op -def streaming_false_negative_rate(predictions, labels, weights=None, +def streaming_false_negative_rate(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -738,30 +799,35 @@ def streaming_false_negative_rate(predictions, labels, weights=None, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'false_negative_rate', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'false_negative_rate', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) false_n, false_negatives_update_op = metrics.false_negatives( - labels, predictions, weights, metrics_collections=None, - updates_collections=None, name=None) + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) true_p, true_positives_update_op = metrics.true_positives( - labels, predictions, weights, metrics_collections=None, - updates_collections=None, name=None) + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) def compute_fnr(fn, tp, name): return array_ops.where( - math_ops.greater(fn + tp, 0), - math_ops.div(fn, fn + tp), - 0, - name) + math_ops.greater(fn + tp, 0), math_ops.div(fn, fn + tp), 0, name) fnr = compute_fnr(false_n, true_p, 'value') - update_op = compute_fnr( - false_negatives_update_op, true_positives_update_op, 'update_op') + update_op = compute_fnr(false_negatives_update_op, true_positives_update_op, + 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, fnr) @@ -772,8 +838,11 @@ def streaming_false_negative_rate(predictions, labels, weights=None, return fnr, update_op -def _streaming_confusion_matrix_at_thresholds( - predictions, labels, thresholds, weights=None, includes=None): +def _streaming_confusion_matrix_at_thresholds(predictions, + labels, + thresholds, + weights=None, + includes=None): """Computes true_positives, false_negatives, true_negatives, false_positives. This function creates up to four local variables, `true_positives`, @@ -861,8 +930,8 @@ def _streaming_confusion_matrix_at_thresholds( if weights is not None: broadcast_weights = weights_broadcast_ops.broadcast_weights( math_ops.to_float(weights), predictions) - weights_tiled = array_ops.tile(array_ops.reshape( - broadcast_weights, [1, -1]), [num_thresholds, 1]) + weights_tiled = array_ops.tile( + array_ops.reshape(broadcast_weights, [1, -1]), [num_thresholds, 1]) thresh_tiled.get_shape().assert_is_compatible_with( weights_tiled.get_shape()) else: @@ -877,8 +946,9 @@ def _streaming_confusion_matrix_at_thresholds( math_ops.logical_and(label_is_pos, pred_is_pos)) if weights_tiled is not None: is_true_positive *= weights_tiled - update_ops['tp'] = state_ops.assign_add( - true_positives, math_ops.reduce_sum(is_true_positive, 1)) + update_ops['tp'] = state_ops.assign_add(true_positives, + math_ops.reduce_sum( + is_true_positive, 1)) values['tp'] = true_positives if 'fn' in includes: @@ -887,8 +957,9 @@ def _streaming_confusion_matrix_at_thresholds( math_ops.logical_and(label_is_pos, pred_is_neg)) if weights_tiled is not None: is_false_negative *= weights_tiled - update_ops['fn'] = state_ops.assign_add( - false_negatives, math_ops.reduce_sum(is_false_negative, 1)) + update_ops['fn'] = state_ops.assign_add(false_negatives, + math_ops.reduce_sum( + is_false_negative, 1)) values['fn'] = false_negatives if 'tn' in includes: @@ -897,8 +968,9 @@ def _streaming_confusion_matrix_at_thresholds( math_ops.logical_and(label_is_neg, pred_is_neg)) if weights_tiled is not None: is_true_negative *= weights_tiled - update_ops['tn'] = state_ops.assign_add( - true_negatives, math_ops.reduce_sum(is_true_negative, 1)) + update_ops['tn'] = state_ops.assign_add(true_negatives, + math_ops.reduce_sum( + is_true_negative, 1)) values['tn'] = true_negatives if 'fp' in includes: @@ -907,36 +979,45 @@ def _streaming_confusion_matrix_at_thresholds( math_ops.logical_and(label_is_neg, pred_is_pos)) if weights_tiled is not None: is_false_positive *= weights_tiled - update_ops['fp'] = state_ops.assign_add( - false_positives, math_ops.reduce_sum(is_false_positive, 1)) + update_ops['fp'] = state_ops.assign_add(false_positives, + math_ops.reduce_sum( + is_false_positive, 1)) values['fp'] = false_positives return values, update_ops -def streaming_true_positives_at_thresholds( - predictions, labels, thresholds, weights=None): +def streaming_true_positives_at_thresholds(predictions, + labels, + thresholds, + weights=None): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=weights, includes=('tp',)) return values['tp'], update_ops['tp'] -def streaming_false_negatives_at_thresholds( - predictions, labels, thresholds, weights=None): +def streaming_false_negatives_at_thresholds(predictions, + labels, + thresholds, + weights=None): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=weights, includes=('fn',)) return values['fn'], update_ops['fn'] -def streaming_false_positives_at_thresholds( - predictions, labels, thresholds, weights=None): +def streaming_false_positives_at_thresholds(predictions, + labels, + thresholds, + weights=None): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=weights, includes=('fp',)) return values['fp'], update_ops['fp'] -def streaming_true_negatives_at_thresholds( - predictions, labels, thresholds, weights=None): +def streaming_true_negatives_at_thresholds(predictions, + labels, + thresholds, + weights=None): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=weights, includes=('tn',)) return values['tn'], update_ops['tn'] @@ -996,8 +1077,8 @@ def streaming_curve_points(labels=None, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope(name, 'curve_points', (labels, predictions, - weights)): + with variable_scope.variable_scope(name, 'curve_points', + (labels, predictions, weights)): if curve != 'ROC' and curve != 'PR': raise ValueError('curve must be either ROC or PR, %s unknown' % (curve)) kepsilon = 1e-7 # to account for floating point imprecisions @@ -1038,9 +1119,14 @@ def streaming_curve_points(labels=None, return points, update_op -def streaming_auc(predictions, labels, weights=None, num_thresholds=200, - metrics_collections=None, updates_collections=None, - curve='ROC', name=None): +def streaming_auc(predictions, + labels, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + curve='ROC', + name=None): """Computes the approximate AUC via a Riemann sum. The `streaming_auc` function creates four local variables, `true_positives`, @@ -1097,14 +1183,24 @@ def streaming_auc(predictions, labels, weights=None, num_thresholds=200, tuple. """ return metrics.auc( - predictions=predictions, labels=labels, weights=weights, - metrics_collections=metrics_collections, num_thresholds=num_thresholds, - curve=curve, updates_collections=updates_collections, name=name) + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + num_thresholds=num_thresholds, + curve=curve, + updates_collections=updates_collections, + name=name) -def streaming_specificity_at_sensitivity( - predictions, labels, sensitivity, weights=None, num_thresholds=200, - metrics_collections=None, updates_collections=None, name=None): +def streaming_specificity_at_sensitivity(predictions, + labels, + sensitivity, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the specificity at a given sensitivity. The `streaming_specificity_at_sensitivity` function creates four local @@ -1154,15 +1250,24 @@ def streaming_specificity_at_sensitivity( or `updates_collections` are not a list or tuple. """ return metrics.specificity_at_sensitivity( - sensitivity=sensitivity, num_thresholds=num_thresholds, - predictions=predictions, labels=labels, weights=weights, + sensitivity=sensitivity, + num_thresholds=num_thresholds, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_sensitivity_at_specificity( - predictions, labels, specificity, weights=None, num_thresholds=200, - metrics_collections=None, updates_collections=None, name=None): +def streaming_sensitivity_at_specificity(predictions, + labels, + specificity, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the sensitivity at a given specificity. The `streaming_sensitivity_at_specificity` function creates four local @@ -1212,16 +1317,23 @@ def streaming_sensitivity_at_specificity( or `updates_collections` are not a list or tuple. """ return metrics.sensitivity_at_specificity( - specificity=specificity, num_thresholds=num_thresholds, - predictions=predictions, labels=labels, weights=weights, + specificity=specificity, + num_thresholds=num_thresholds, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_precision_at_thresholds(predictions, labels, thresholds, +def streaming_precision_at_thresholds(predictions, + labels, + thresholds, weights=None, metrics_collections=None, - updates_collections=None, name=None): + updates_collections=None, + name=None): """Computes precision values for different `thresholds` on `predictions`. The `streaming_precision_at_thresholds` function creates four local variables, @@ -1266,14 +1378,21 @@ def streaming_precision_at_thresholds(predictions, labels, thresholds, """ return metrics.precision_at_thresholds( thresholds=thresholds, - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_recall_at_thresholds(predictions, labels, thresholds, - weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_recall_at_thresholds(predictions, + labels, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes various recall values for different `thresholds` on `predictions`. The `streaming_recall_at_thresholds` function creates four local variables, @@ -1316,14 +1435,21 @@ def streaming_recall_at_thresholds(predictions, labels, thresholds, """ return metrics.recall_at_thresholds( thresholds=thresholds, - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_false_positive_rate_at_thresholds( - predictions, labels, thresholds, weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_false_positive_rate_at_thresholds(predictions, + labels, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes various fpr values for different `thresholds` on `predictions`. The `streaming_false_positive_rate_at_thresholds` function creates two @@ -1365,20 +1491,19 @@ def streaming_false_positive_rate_at_thresholds( either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'false_positive_rate_at_thresholds', - (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'false_positive_rate_at_thresholds', + (predictions, labels, weights)): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights, includes=('fp', 'tn')) # Avoid division by zero. epsilon = 1e-7 + def compute_fpr(fp, tn, name): return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name) fpr = compute_fpr(values['fp'], values['tn'], 'value') - update_op = compute_fpr( - update_ops['fp'], update_ops['tn'], 'update_op') + update_op = compute_fpr(update_ops['fp'], update_ops['tn'], 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, fpr) @@ -1389,9 +1514,13 @@ def streaming_false_positive_rate_at_thresholds( return fpr, update_op -def streaming_false_negative_rate_at_thresholds( - predictions, labels, thresholds, weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_false_negative_rate_at_thresholds(predictions, + labels, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes various fnr values for different `thresholds` on `predictions`. The `streaming_false_negative_rate_at_thresholds` function creates two @@ -1433,20 +1562,19 @@ def streaming_false_negative_rate_at_thresholds( either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'false_negative_rate_at_thresholds', - (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'false_negative_rate_at_thresholds', + (predictions, labels, weights)): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights, includes=('fn', 'tp')) # Avoid division by zero. epsilon = 1e-7 + def compute_fnr(fn, tp, name): return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name) fnr = compute_fnr(values['fn'], values['tp'], 'value') - update_op = compute_fnr( - update_ops['fn'], update_ops['tp'], 'update_op') + update_op = compute_fnr(update_ops['fn'], update_ops['tp'], 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, fnr) @@ -1469,8 +1597,12 @@ def _at_k_name(name, k=None, class_id=None): @deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, ' 'and reshape labels from [batch_size] to [batch_size, 1].') -def streaming_recall_at_k(predictions, labels, k, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_recall_at_k(predictions, + labels, + k, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the recall@k of the predictions with respect to dense labels. @@ -1516,11 +1648,8 @@ def streaming_recall_at_k(predictions, labels, k, weights=None, tuple. """ in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k)) - return streaming_mean(in_top_k, - weights, - metrics_collections, - updates_collections, - name or _at_k_name('recall', k)) + return streaming_mean(in_top_k, weights, metrics_collections, + updates_collections, name or _at_k_name('recall', k)) # TODO(ptucker): Validate range of values in labels? @@ -1599,10 +1728,14 @@ def streaming_sparse_recall_at_k(predictions, are not a list or tuple. """ return metrics.recall_at_k( - k=k, class_id=class_id, - predictions=predictions, labels=labels, weights=weights, + k=k, + class_id=class_id, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) # TODO(ptucker): Validate range of values in labels? @@ -1684,10 +1817,14 @@ def streaming_sparse_precision_at_k(predictions, are not a list or tuple. """ return metrics.sparse_precision_at_k( - k=k, class_id=class_id, - predictions=predictions, labels=labels, weights=weights, + k=k, + class_id=class_id, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) # TODO(ptucker): Validate range of values in labels? @@ -1766,9 +1903,8 @@ def streaming_sparse_precision_at_top_k(top_k_predictions, ValueError: If `top_k_predictions` has rank < 2. """ default_name = _at_k_name('precision', class_id=class_id) - with ops.name_scope( - name, default_name, - (top_k_predictions, labels, weights)) as name_scope: + with ops.name_scope(name, default_name, + (top_k_predictions, labels, weights)) as name_scope: return metrics_impl._sparse_precision_at_top_k( # pylint: disable=protected-access labels=labels, predictions_idx=top_k_predictions, @@ -1848,8 +1984,8 @@ def sparse_recall_at_top_k(labels, are not a list or tuple. """ default_name = _at_k_name('recall', class_id=class_id) - with ops.name_scope(name, default_name, (top_k_predictions, labels, - weights)) as name_scope: + with ops.name_scope(name, default_name, + (top_k_predictions, labels, weights)) as name_scope: return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access labels=labels, predictions_idx=top_k_predictions, @@ -1919,9 +2055,13 @@ def streaming_sparse_average_precision_at_k(predictions, value matches `metric`. """ return metrics.sparse_average_precision_at_k( - k=k, predictions=predictions, labels=labels, weights=weights, + k=k, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) def streaming_sparse_average_precision_at_top_k(top_k_predictions, @@ -1987,7 +2127,9 @@ def streaming_sparse_average_precision_at_top_k(top_k_predictions, name=name) -def streaming_mean_absolute_error(predictions, labels, weights=None, +def streaming_mean_absolute_error(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2035,12 +2177,18 @@ def streaming_mean_absolute_error(predictions, labels, weights=None, tuple. """ return metrics.mean_absolute_error( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, +def streaming_mean_relative_error(predictions, + labels, + normalizer, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2089,12 +2237,18 @@ def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, tuple. """ return metrics.mean_relative_error( - normalizer=normalizer, predictions=predictions, labels=labels, - weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + normalizer=normalizer, + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) -def streaming_mean_squared_error(predictions, labels, weights=None, +def streaming_mean_squared_error(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2142,12 +2296,17 @@ def streaming_mean_squared_error(predictions, labels, weights=None, tuple. """ return metrics.mean_squared_error( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_root_mean_squared_error(predictions, labels, weights=None, +def streaming_root_mean_squared_error(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2195,9 +2354,12 @@ def streaming_root_mean_squared_error(predictions, labels, weights=None, tuple. """ return metrics.root_mean_squared_error( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) def streaming_covariance(predictions, @@ -2253,8 +2415,8 @@ def streaming_covariance(predictions, ValueError: If labels and predictions are of different sizes or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'covariance', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'covariance', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -2298,22 +2460,22 @@ def streaming_covariance(predictions, # prev_mean_label is E[y_A] in the update equation prev_mean_label = update_mean_label - delta_mean_label - unweighted_batch_coresiduals = ( - (predictions - batch_mean_prediction) * (labels - batch_mean_label)) + unweighted_batch_coresiduals = ((predictions - batch_mean_prediction) * + (labels - batch_mean_label)) # batch_comoment is C_B in the update equation if weights is None: batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals) else: - batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals * - weights) + batch_comoment = math_ops.reduce_sum( + unweighted_batch_coresiduals * weights) # View delta_comoment as = C_AB - C_A in the update equation above. # Since C_A is stored in a var, by how much do we need to increment that var # to make the var = C_AB? - delta_comoment = (batch_comoment + - (prev_mean_prediction - batch_mean_prediction) * - (prev_mean_label - batch_mean_label) * - (prev_count * batch_count / update_count)) + delta_comoment = ( + batch_comoment + (prev_mean_prediction - batch_mean_prediction) * + (prev_mean_label - batch_mean_label) * + (prev_count * batch_count / update_count)) update_comoment = state_ops.assign_add(comoment, delta_comoment) covariance = array_ops.where( @@ -2387,8 +2549,8 @@ def streaming_pearson_correlation(predictions, `weights` is the wrong size, or if either `metrics_collections` or `updates_collections` are not a `list` or `tuple`. """ - with variable_scope.variable_scope( - name, 'pearson_r', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'pearson_r', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -2405,13 +2567,14 @@ def streaming_pearson_correlation(predictions, pearson_r = math_ops.truediv( cov, - math_ops.multiply(math_ops.sqrt(var_predictions), - math_ops.sqrt(var_labels)), + math_ops.multiply( + math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)), name='pearson_r') update_op = math_ops.truediv( update_cov, - math_ops.multiply(math_ops.sqrt(update_var_predictions), - math_ops.sqrt(update_var_labels)), + math_ops.multiply( + math_ops.sqrt(update_var_predictions), + math_ops.sqrt(update_var_labels)), name='update_op') if metrics_collections: @@ -2425,7 +2588,10 @@ def streaming_pearson_correlation(predictions, # TODO(nsilberman): add a 'normalized' flag so that the user can request # normalization if the inputs are not normalized. -def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, +def streaming_mean_cosine_distance(predictions, + labels, + dim, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2471,12 +2637,11 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.multiply(predictions, labels) - radial_diffs = math_ops.reduce_sum(radial_diffs, - reduction_indices=[dim,], - keep_dims=True) - mean_distance, update_op = streaming_mean(radial_diffs, weights, - None, - None, + radial_diffs = math_ops.reduce_sum( + radial_diffs, reduction_indices=[ + dim, + ], keep_dims=True) + mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None, name or 'mean_cosine_distance') mean_distance = math_ops.subtract(1.0, mean_distance) update_op = math_ops.subtract(1.0, update_op) @@ -2490,7 +2655,9 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, return mean_distance, update_op -def streaming_percentage_less(values, threshold, weights=None, +def streaming_percentage_less(values, + threshold, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2530,9 +2697,12 @@ def streaming_percentage_less(values, threshold, weights=None, or tuple. """ return metrics.percentage_below( - values=values, threshold=threshold, weights=weights, + values=values, + threshold=threshold, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) def streaming_mean_iou(predictions, @@ -2584,9 +2754,13 @@ def streaming_mean_iou(predictions, tuple. """ return metrics.mean_iou( - num_classes=num_classes, predictions=predictions, labels=labels, - weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + num_classes=num_classes, + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) def _next_array_size(required_size, growth_factor=1.5): @@ -2601,9 +2775,9 @@ def _next_array_size(required_size, growth_factor=1.5): tf.Tensor with dtype=int32 giving the next array size. """ exponent = math_ops.ceil( - math_ops.log(math_ops.cast(required_size, dtypes.float32)) - / math_ops.log(math_ops.cast(growth_factor, dtypes.float32))) - return math_ops.cast(math_ops.ceil(growth_factor ** exponent), dtypes.int32) + math_ops.log(math_ops.cast(required_size, dtypes.float32)) / math_ops.log( + math_ops.cast(growth_factor, dtypes.float32))) + return math_ops.cast(math_ops.ceil(growth_factor**exponent), dtypes.int32) def streaming_concat(values, @@ -2660,8 +2834,7 @@ def streaming_concat(values, if not 0 <= axis < ndim: raise ValueError('axis = %r not in [0, %r)' % (axis, ndim)) - fixed_shape = [dim.value for n, dim in enumerate(values_shape) - if n != axis] + fixed_shape = [dim.value for n, dim in enumerate(values_shape) if n != axis] if any(value is None for value in fixed_shape): raise ValueError('all dimensions of `values` other than the dimension to ' 'concatenate along must have statically known size') @@ -2804,14 +2977,14 @@ def _remove_squeezable_dimensions(predictions, labels, weights): # Use static rank. if weights_rank - predictions_rank == 1: weights = array_ops.squeeze(weights, [-1]) - elif (weights_rank is None) or ( - weights_shape.dims[-1].is_compatible_with(1)): + elif (weights_rank is + None) or (weights_shape.dims[-1].is_compatible_with(1)): # Use dynamic rank weights = control_flow_ops.cond( - math_ops.equal(array_ops.rank(weights), - math_ops.add(array_ops.rank(predictions), 1)), - lambda: array_ops.squeeze(weights, [-1]), - lambda: weights) + math_ops.equal( + array_ops.rank(weights), + math_ops.add(array_ops.rank(predictions), 1)), + lambda: array_ops.squeeze(weights, [-1]), lambda: weights) return predictions, labels, weights From fafff08cbc3b952d60ee98914c234bb6af09b968 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 11:57:40 -0700 Subject: [PATCH 33/41] Adds the k-MC2 algorithm for efficient seeding of mini batch k-means in TensorFlow. PiperOrigin-RevId: 172914154 --- .../contrib/factorization/g3doc/kmeans.md | 12 +- .../factorization/kernels/clustering_ops.cc | 52 +++++++ .../kernels/clustering_ops_test.cc | 56 ++++++++ .../factorization/ops/clustering_ops.cc | 19 +++ .../kernel_tests/clustering_ops_test.py | 57 ++++++++ .../python/ops/clustering_ops.py | 127 ++++++++++++++++-- 6 files changed, 307 insertions(+), 16 deletions(-) diff --git a/tensorflow/contrib/factorization/g3doc/kmeans.md b/tensorflow/contrib/factorization/g3doc/kmeans.md index b55c9d09ad3..c1843f0bf07 100644 --- a/tensorflow/contrib/factorization/g3doc/kmeans.md +++ b/tensorflow/contrib/factorization/g3doc/kmeans.md @@ -24,7 +24,11 @@ the full-batch version. approach for computing the initial cluster assignments that is expensive but is typically less prone to getting stuck in bad local minima. -We provide distributed implementations of both full-batch and mini-batch -K-Means algorithm. Both K-Means++ and random initialization are supported. -The user can also choose between **Cosine** and **Squared Euclidean** distance -metrics. +**[k-MC2](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/view/12147/11759)** +provides a very fast seeding method that provides high quality centers +comparable to K-Means++ seeding. k-MC2 works particularly well if it is combined +with Mini-batch K-Means. + +We provide distributed implementations of both full-batch and mini-batch K-Means +algorithm. K-Means++, k-MC2 and random initialization are supported. The user +can also choose between **Cosine** and **Squared Euclidean** distance metrics. diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc index a2136c08bbc..dd61f59585a 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc @@ -224,6 +224,58 @@ class KmeansPlusPlusInitializationOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU), KmeansPlusPlusInitializationOp); +// Implementation of one single Markov Chain for the k-MC^2 algorithm +class KMC2ChainInitializationOp : public OpKernel { + public: + explicit KMC2ChainInitializationOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, + context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64})); + } + + void Compute(OpKernelContext* context) override { + const Tensor& distances_tensor = context->input(0); + const Tensor& seed_tensor = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()), + InvalidArgument("Input distances should be a vector.")); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()), + InvalidArgument("Input seed should be a scalar.")); + const int64 num_points = distances_tensor.dim_size(0); + const int64 seed = seed_tensor.scalar()(); + OP_REQUIRES(context, num_points > 0, + InvalidArgument("Expected distances_tensor.size() > 0.")); + + random::PhiloxRandom random(seed); + random::SimplePhilox rng(&random); + + auto distances = distances_tensor.flat(); + // Set the initial state of the Markov chain to be the first candidate. + int64 selected_index = 0; + float selected_distance = distances(selected_index); + // Build a Markov chain of length num_points. + for (int64 i = 1; i < num_points; ++i) { + const float candidate_distance = distances(i); + // Set the next state of the Markov chain to be the candidate with + // probability min(1, candidate_distance/selected_distance). + if (candidate_distance > rng.RandFloat() * selected_distance) { + selected_index = i; + selected_distance = candidate_distance; + } + } + + Tensor* output_sampled_index_tensor; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), + &output_sampled_index_tensor)); + auto output = output_sampled_index_tensor->scalar(); + // Return the last state of the Markov chain as the new center. + output() = selected_index; + } +}; + +REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU), + KMC2ChainInitializationOp); + // Operator for computing the nearest neighbors for a set of points. class NearestNeighborsOp : public OpKernel { public: diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc b/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc index c4a96b048db..8172a7cebb8 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc @@ -116,6 +116,62 @@ RUN_BM_KmeansPlusPlusInitialization(k3RetriesPerSample); #undef RUN_BM_KmeansPlusPlusInitialization #undef BENCHMARK_KMEANS_PLUS_PLUS +Graph* SetUpKMC2Initialization(int num_points) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor distances(DT_FLOAT, TensorShape({num_points})); + Tensor seed(DT_INT64, TensorShape({})); + distances.flat().setRandom(); + seed.flat().setConstant(12345); + + TF_CHECK_OK( + NodeBuilder("KMC2ChainInitializationOp", "KMC2ChainInitialization") + .Input(test::graph::Constant(g, distances)) + .Input(test::graph::Constant(g, seed)) + .Finalize(g, nullptr /* node */)); + return g; +} + +template +void BM_KMC2Initialization(int iters) { + testing::StopTiming(); + testing::ItemsProcessed(static_cast(iters) * num_points * num_dims * + num_to_sample); + testing::UseRealTime(); + Graph* g = SetUpKMC2Initialization(num_points); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); +} +#define BENCHMARK_KMC2(p, c, d) \ + void BM_KMC2Initialization_##p##_##c##_##d(int iters) { \ + BM_KMC2Initialization(iters); \ + } \ + BENCHMARK(BM_KMC2Initialization_##p##_##c##_##d); + +#define RUN_BM_KMC2Initialization \ + BENCHMARK_KMC2(k10Points, k2Centers, k100Dim); \ + BENCHMARK_KMC2(k10Points, k5Centers, k100Dim); \ + BENCHMARK_KMC2(k10Points, k10Centers, k100Dim); \ + BENCHMARK_KMC2(k100Points, k10Centers, k100Dim); \ + BENCHMARK_KMC2(k100Points, k20Centers, k100Dim); \ + BENCHMARK_KMC2(k100Points, k50Centers, k100Dim); \ + BENCHMARK_KMC2(k100Points, k100Centers, k100Dim); \ + BENCHMARK_KMC2(k1kPoints, k100Centers, k100Dim); \ + BENCHMARK_KMC2(k1kPoints, k200Centers, k100Dim); \ + BENCHMARK_KMC2(k1kPoints, k500Centers, k100Dim); \ + BENCHMARK_KMC2(k1kPoints, k1kCenters, k100Dim); \ + BENCHMARK_KMC2(k10kPoints, k100Centers, k100Dim); \ + BENCHMARK_KMC2(k10kPoints, k200Centers, k100Dim); \ + BENCHMARK_KMC2(k10kPoints, k500Centers, k100Dim); \ + BENCHMARK_KMC2(k10kPoints, k1kCenters, k100Dim); \ + BENCHMARK_KMC2(k1MPoints, k100Centers, k100Dim); \ + BENCHMARK_KMC2(k1MPoints, k200Centers, k100Dim); \ + BENCHMARK_KMC2(k1MPoints, k500Centers, k100Dim); \ + BENCHMARK_KMC2(k1MPoints, k1kCenters, k100Dim) + +RUN_BM_KMC2Initialization; +#undef RUN_BM_KMC2Initialization +#undef BENCHMARK_KMC2 + Graph* SetUpNearestNeighbors(int num_dims, int num_points, int num_centers, int k) { Graph* g = new Graph(OpRegistry::Global()); diff --git a/tensorflow/contrib/factorization/ops/clustering_ops.cc b/tensorflow/contrib/factorization/ops/clustering_ops.cc index f2dfcf7ed0f..2686702c1d5 100644 --- a/tensorflow/contrib/factorization/ops/clustering_ops.cc +++ b/tensorflow/contrib/factorization/ops/clustering_ops.cc @@ -44,6 +44,25 @@ num_retries_per_sample: Scalar. For each row that is sampled, this parameter samples: Matrix of shape (num_to_sample, d). The sampled rows. )"); +REGISTER_OP("KMC2ChainInitialization") + .Input("distances: float32") + .Input("seed: int64") + .Output("index: int64") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"( +Returns the index of a data point that should be added to the seed set. + +Entries in distances are assumed to be squared distances of candidate points to +the already sampled centers in the seed set. The op constructs one Markov chain +of the k-MC^2 algorithm and returns the index of one candidate point to be added +as an additional cluster center. + +distances: Vector with squared distances to the closest previously sampled + cluster center for each candidate point. +seed: Scalar. Seed for initializing the random number generator. +index: Scalar with the index of the sampled point. +)"); + REGISTER_OP("NearestNeighbors") .Input("points: float32") .Input("centers: float32") diff --git a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py index 450f64063a2..1322f7ce5f8 100644 --- a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py +++ b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py @@ -55,6 +55,63 @@ class KmeansPlusPlusInitializationTest(test.TestCase): self.runTestWithSeed(seed) +class KMC2InitializationTest(test.TestCase): + + def runTestWithSeed(self, seed): + with self.test_session(): + distances = np.zeros(1000).astype(np.float32) + distances[6] = 10e7 + distances[4] = 10e3 + + sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed) + self.assertEquals(sampled_point.eval(), 6) + distances[6] = 0.0 + sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed) + self.assertEquals(sampled_point.eval(), 4) + + def testBasic(self): + for seed in range(100): + self.runTestWithSeed(seed) + + +class KMC2InitializationLargeTest(test.TestCase): + + def setUp(self): + self._distances = np.zeros(1001) + self._distances[500] = 100.0 + self._distances[1000] = 50.0 + + def testBasic(self): + with self.test_session(): + counts = {} + seed = 0 + for i in range(50): + sample = clustering_ops.kmc2_chain_initialization( + self._distances, seed + i).eval() + counts[sample] = counts.get(sample, 0) + 1 + self.assertEquals(len(counts), 2) + self.assertTrue(500 in counts) + self.assertTrue(1000 in counts) + self.assertGreaterEqual(counts[500], 5) + self.assertGreaterEqual(counts[1000], 5) + + +class KMC2InitializationCornercaseTest(test.TestCase): + + def setUp(self): + self._distances = np.zeros(10) + + def runTestWithSeed(self, seed): + with self.test_session(): + sampled_point = clustering_ops.kmc2_chain_initialization( + self._distances, seed) + self.assertEquals(sampled_point.eval(), 0) + + def testBasic(self): + for seed in range(100): + self.runTestWithSeed(seed) + + # A simple test that can be verified by hand. class NearestCentersTest(test.TestCase): diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index d7320aeb3de..96cc80ce241 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -50,6 +50,7 @@ COSINE_DISTANCE = 'cosine' RANDOM_INIT = 'random' KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus' +KMC2_INIT = 'kmc2' # The name of the variable holding the cluster centers. Used by the Estimator. CLUSTERS_VAR_NAME = 'clusters' @@ -66,7 +67,8 @@ class KMeans(object): use_mini_batch=False, mini_batch_steps_per_iteration=1, random_seed=0, - kmeans_plus_plus_num_retries=2): + kmeans_plus_plus_num_retries=2, + kmc2_chain_length=200): """Creates an object for generating KMeans clustering graph. This class implements the following variants of K-means algorithm: @@ -95,7 +97,8 @@ class KMeans(object): exactly like a full-batch version. Args: - inputs: An input tensor or list of input tensors + inputs: An input tensor or list of input tensors. It is assumed that the + data points have been previously randomly permuted. num_clusters: An integer tensor specifying the number of clusters. This argument is ignored if initial_clusters is a tensor or numpy array. initial_clusters: Specifies the clusters used during initialization. One @@ -104,6 +107,7 @@ class KMeans(object): - a function f(inputs, k) that returns up to k centers from `inputs`. - "random": Choose centers randomly from `inputs`. - "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`. + - "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`. In the last three cases, one batch of `inputs` may not yield `num_clusters` centers, in which case initialization will require multiple batches until enough centers are chosen. In the case of @@ -121,13 +125,17 @@ class KMeans(object): additional points to draw from the current distribution before selecting the best. If a negative value is specified, a heuristic is used to sample O(log(num_to_sample)) additional points. + kmc2_chain_length: Determines how many candidate points are used by the + k-MC2 algorithm to produce one new cluster centers. If a (mini-)batch + contains less points, one new cluster center is generated from the + (mini-)batch. Raises: ValueError: An invalid argument was passed to initial_clusters or distance_metric. """ if isinstance(initial_clusters, str) and initial_clusters not in [ - RANDOM_INIT, KMEANS_PLUS_PLUS_INIT + RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT ]: raise ValueError( "Unsupported initialization algorithm '%s'" % initial_clusters) @@ -141,6 +149,7 @@ class KMeans(object): self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration) self._random_seed = random_seed self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries + self._kmc2_chain_length = kmc2_chain_length @classmethod def _distance_graph(cls, inputs, clusters, distance_metric): @@ -302,9 +311,10 @@ class KMeans(object): else: cluster_centers_updated = cluster_centers update_in_steps = None - cluster_counts = (variable_scope.variable( - array_ops.ones([num_clusters], dtype=dtypes.int64)) - if self._use_mini_batch else None) + cluster_counts = ( + variable_scope.variable( + array_ops.ones([num_clusters], dtype=dtypes.int64)) + if self._use_mini_batch else None) return (cluster_centers, cluster_centers_initialized, cluster_counts, cluster_centers_updated, update_in_steps) @@ -359,7 +369,7 @@ class KMeans(object): init_op = _InitializeClustersOpFactory( self._inputs, num_clusters, initial_clusters, self._distance_metric, self._random_seed, self._kmeans_plus_plus_num_retries, - cluster_centers_var, cluster_centers_updated, + self._kmc2_chain_length, cluster_centers_var, cluster_centers_updated, cluster_centers_initialized).op() cluster_centers = cluster_centers_var @@ -520,8 +530,9 @@ class KMeans(object): array_ops.reshape(array_ops.shape(inp)[0], [-1])), [-1, 1]), cluster_idx, num_clusters)) with ops.colocate_with(cluster_centers, ignore_existing=True): - new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast( - math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon) + new_clusters_centers = math_ops.add_n(cluster_sums) / ( + math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + + epsilon) if self._clusters_l2_normalized(): new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1) return state_ops.assign(cluster_centers, new_clusters_centers) @@ -548,9 +559,12 @@ class _InitializeClustersOpFactory(object): cluster_centers_initialized := true """ + # TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case. + def __init__(self, inputs, num_clusters, initial_clusters, distance_metric, - random_seed, kmeans_plus_plus_num_retries, cluster_centers, - cluster_centers_updated, cluster_centers_initialized): + random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length, + cluster_centers, cluster_centers_updated, + cluster_centers_initialized): """Creates an op factory. Args: @@ -560,6 +574,7 @@ class _InitializeClustersOpFactory(object): distance_metric: See KMeans constructor. random_seed: See KMeans constructor. kmeans_plus_plus_num_retries: See KMeans constructor. + kmc2_chain_length: See KMeans constructor. cluster_centers: The TF variable holding the initial centers. It may already contain some centers when the op is executed. cluster_centers_updated: A second TF variable to hold a copy of the @@ -575,6 +590,7 @@ class _InitializeClustersOpFactory(object): self._distance_metric = distance_metric self._random_seed = random_seed self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries + self._kmc2_chain_length = kmc2_chain_length self._cluster_centers = cluster_centers self._cluster_centers_updated = cluster_centers_updated self._cluster_centers_initialized = cluster_centers_initialized @@ -604,6 +620,90 @@ class _InitializeClustersOpFactory(object): math_ops.to_int64(self._num_remaining), self._random_seed, self._kmeans_plus_plus_num_retries) + def _kmc2_multiple_centers(self): + """Adds new initial cluster centers using the k-MC2 algorithm. + + In each call to the op, the provided batch is split into subsets based on + the specified `kmc2_chain_length`. On each subset, a single Markov chain of + the k-MC2 algorithm is used to add *one* new center cluster center. If there + are less than `kmc2_chain_length` points in the subset, a single center is + added using one Markov chain on the full input. It is assumed that the + provided batch has previously been randomly permuted. Otherwise, k-MC2 may + return suboptimal centers. + + Returns: + An op that adds new cluster centers. + """ + # The op only operates on the first shard of data. + first_shard = self._inputs[0] + # Number of points in the input that can be used. + batch_size = array_ops.shape(first_shard)[0] + # Maximum number of subsets such that the size of each subset is at least + # `kmc2_chain_length`. Final subsets may be larger. + max_to_sample = math_ops.cast( + batch_size / self._kmc2_chain_length, dtype=dtypes.int32) + # We sample at least one new center and at most all remaining centers. + num_to_sample = math_ops.maximum( + math_ops.minimum(self._num_remaining, max_to_sample), 1) + + def _cond(i, _): + """Stopping condition for the while loop.""" + return math_ops.less(i, num_to_sample) + + def _body(i, _): + """Body that adds a single new center based on a subset.""" + + def _sample_random(): + """Returns a random point as a cluster center.""" + # By assumption the batch is reshuffled and _sample_random is always + # called for i=0. Hence, we simply return the first point. + new_center = array_ops.reshape(first_shard[0], [1, -1]) + if self._distance_metric == COSINE_DISTANCE: + new_center = nn_impl.l2_normalize(new_center, dim=1) + return new_center + + def _sample_kmc2_chain(): + """Returns previous centers as well as a new center sampled using k-MC2. + """ + # Extract the subset from the underlying batch. + start = i * self._kmc2_chain_length + end = start + self._kmc2_chain_length + subset = first_shard[start:end] + # Compute the distances from points in the subset to previous centers. + _, distances = gen_clustering_ops.nearest_neighbors( + subset, self._cluster_centers, 1) + # Sample index of new center using k-MC2 Markov chain. + new_center_index = gen_clustering_ops.kmc2_chain_initialization( + array_ops.squeeze(distances), self._random_seed) + # Extract actual new center. + newly_sampled_center = array_ops.reshape(subset[new_center_index], + [1, -1]) + # Return concatenation with previously sampled centers. + if self._distance_metric == COSINE_DISTANCE: + newly_sampled_center = nn_impl.l2_normalize( + newly_sampled_center, dim=1) + return array_ops.concat([self._cluster_centers, newly_sampled_center], + 0) + + # Obtain a random point if there are no previously sampled centers. + # Otherwise, construct a k-MC2 Markov chain. + new_centers = control_flow_ops.cond( + math_ops.equal(self._num_selected, 0), _sample_random, + _sample_kmc2_chain) + # Assign new cluster centers to underlying variable. + assigned_centers = state_ops.assign( + self._cluster_centers, new_centers, validate_shape=False) + if self._cluster_centers_updated is not self._cluster_centers: + assigned_centers = state_ops.assign( + self._cluster_centers_updated, + assigned_centers, + validate_shape=False) + return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0] + + # Add num_to_sample new data points. + _, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0]) + return num_remaining + def _greedy_batch_sampler(self, sampler): # If the input dataset size is smaller than the number of centers # remaining, choose the entire input dataset as centers. This can happen @@ -657,7 +757,10 @@ class _InitializeClustersOpFactory(object): with ops.control_dependencies([ check_ops.assert_positive(self._num_remaining), ]): - num_now_remaining = self._add_new_centers() + if self._initial_clusters == KMC2_INIT: + num_now_remaining = self._kmc2_multiple_centers() + else: + num_now_remaining = self._add_new_centers() return control_flow_ops.cond( math_ops.equal(num_now_remaining, 0), lambda: state_ops.assign(self._cluster_centers_initialized, True), From 58bdae2f8b7bead45537092f39f6fa6fd15c50d0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 12:08:44 -0700 Subject: [PATCH 34/41] Work around for compiler bug in GCC on Android. PiperOrigin-RevId: 172915900 --- tensorflow/core/kernels/transpose_functor.h | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/kernels/transpose_functor.h b/tensorflow/core/kernels/transpose_functor.h index 9781fe3b618..add4635331e 100644 --- a/tensorflow/core/kernels/transpose_functor.h +++ b/tensorflow/core/kernels/transpose_functor.h @@ -201,17 +201,26 @@ Status DoTransposeImpl(const Device& d, const Tensor& in, case DT_COMPLEX64: if (conjugate) { - Transpose::run(d, in, perm, out); +#if defined(__ANDROID__) and !defined(__clang__) + // Workaround for GCC compiler bug in Android toolchain. + return errors::Unimplemented( + "Conjugate transpose of complex64 not supported for GCC on " + "Android."); +#else + Transpose::run(d, in, perm, out); +#endif } else { - Transpose::run(d, in, perm, out); + Transpose::run(d, in, perm, out); } break; case DT_COMPLEX128: if (conjugate) { - Transpose::run(d, in, perm, out); + Transpose::run(d, in, perm, + out); } else { - Transpose::run(d, in, perm, out); + Transpose::run(d, in, perm, + out); } break; From b68a3f2e445cdc749f380387b910f6eac72e5dcf Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Fri, 20 Oct 2017 12:26:09 -0700 Subject: [PATCH 35/41] Iterating through a map in protobuf is essentially nondeterministic. This CL enables us to traverse the map in a deterministic order by sorting the keys first. PiperOrigin-RevId: 172918084 --- tensorflow/compiler/xla/service/user_computation.cc | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index b3506b72bf5..065d2580c68 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -1843,10 +1844,17 @@ UserComputation::GetEmbeddedComputations( XLA_VLOG_LINES(3, session_computation_.DebugString()); std::vector computations; + std::vector sorted_handles; for (const auto& handle_request : session_computation_.requests()) { - int64 handle_value = handle_request.first; + sorted_handles.push_back(handle_request.first); + } + std::sort(sorted_handles.begin(), sorted_handles.end()); + for (int64 handle : sorted_handles) { + const auto& handle_request = session_computation_.requests().find(handle); + CHECK(handle_request != session_computation_.requests().end()); + int64 handle_value = handle_request->first; if (handle_value <= version) { - const OperationRequest& request = handle_request.second; + const OperationRequest& request = handle_request->second; switch (request.request().op_case()) { case OpRequest::kCallRequest: { CHECK_EQ(1, request.embedded_computation_versions_size()); From aada11e19a1ceb901f490aa3c064f2778cb2acf2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 12:46:29 -0700 Subject: [PATCH 36/41] Exposes the read_batch_size argument to read_batch_features. PiperOrigin-RevId: 172920603 --- .../learn/python/learn/learn_io/graph_io.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index bdb88b89bb3..4b34fc62849 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -442,7 +442,8 @@ def read_keyed_batch_features(file_pattern, feature_queue_capacity=100, num_enqueue_threads=2, parse_fn=None, - name=None): + name=None, + read_batch_size=None): """Adds operations to read, queue, batch and parse `Example` protos. Given file pattern (or list of files), will setup a queue for file names, @@ -482,6 +483,8 @@ def read_keyed_batch_features(file_pattern, parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. If `None`, defaults to `batch_size`. Returns: Returns tuple of: @@ -493,6 +496,7 @@ def read_keyed_batch_features(file_pattern, """ with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope: + if read_batch_size is None: read_batch_size = batch_size keys, examples = read_keyed_batch_examples( file_pattern, batch_size, @@ -501,7 +505,7 @@ def read_keyed_batch_features(file_pattern, num_epochs=num_epochs, queue_capacity=queue_capacity, num_threads=reader_num_threads, - read_batch_size=batch_size, + read_batch_size=read_batch_size, parse_fn=parse_fn, name=scope) # Parse the example. @@ -727,7 +731,8 @@ def read_batch_features(file_pattern, reader_num_threads=1, num_enqueue_threads=2, parse_fn=None, - name=None): + name=None, + read_batch_size=None): """Adds operations to read, queue, batch and parse `Example` protos. Given file pattern (or list of files), will setup a queue for file names, @@ -768,6 +773,8 @@ def read_batch_features(file_pattern, parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. If `None`, defaults to `batch_size`. Returns: A dict of `Tensor` or `SparseTensor` objects for each in `features`. @@ -786,6 +793,7 @@ def read_batch_features(file_pattern, reader_num_threads=reader_num_threads, feature_queue_capacity=feature_queue_capacity, num_enqueue_threads=num_enqueue_threads, + read_batch_size=read_batch_size, parse_fn=parse_fn, name=name) return features From 5c331cfd573984287778aab02794dd86ba1f3006 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 12:47:57 -0700 Subject: [PATCH 37/41] The new array class provides a way to simplify the implementation of these classes by eliminating a large number of duplicated code. Removing the old API is non-trivial because of the existing users outside of tensorflow. PiperOrigin-RevId: 172920837 --- .../compiler/xla/client/computation_builder.h | 41 +++--- tensorflow/compiler/xla/layout_util.cc | 4 + tensorflow/compiler/xla/layout_util.h | 1 + tensorflow/compiler/xla/literal_util.h | 121 ++++++++---------- 4 files changed, 85 insertions(+), 82 deletions(-) diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index cdd9c8847f5..93c2a806780 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -138,6 +138,11 @@ class ComputationBuilder { ComputationDataHandle ConstantR2( std::initializer_list> values); template + ComputationDataHandle ConstantFromArrayWithLayout( + const Array& values, const Layout& layout); + template + ComputationDataHandle ConstantFromArray(const Array& values); + template ComputationDataHandle ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout); template @@ -909,49 +914,55 @@ ComputationDataHandle ComputationBuilder::ConstantR2( [&values](Literal* literal) { literal->PopulateR2(values); }); } +template +ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( + const Array& values, const Layout& layout) { + return ConstantOp([&values, &layout](Literal* literal) { + literal->PopulateFromArrayWithLayout(values, layout); + }); +} + +template +ComputationDataHandle ComputationBuilder::ConstantFromArray( + const Array& values) { + return ConstantOp( + [&values](Literal* literal) { literal->PopulateFromArray(values); }); +} + template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR2FromArray2DWithLayout(values, layout); - }); + return ConstantFromArrayWithLayout(values, layout); } template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( const Array2D& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); }); + return ConstantFromArray(values); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR3FromArray3DWithLayout(values, layout); - }); + return ConstantFromArrayWithLayout(values, layout); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( const Array3D& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); }); + return ConstantFromArray(values); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR4FromArray4DWithLayout(values, layout); - }); + return ConstantFromArrayWithLayout(values, layout); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( const Array4D& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); }); + return ConstantFromArray(values); } } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 011fc3c194e..5c2cc2a7a99 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -83,6 +83,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return CreateDefaultLayoutForRank(shape.dimensions_size()); } +/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) { + return CreateDefaultLayoutForRank(rank); +} + /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() { return CreateDefaultLayoutForRank(2); } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 5de0a653f66..bc42e222292 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -40,6 +40,7 @@ class LayoutUtil { static Layout GetDefaultLayoutForShape(const Shape& shape); // Helper functions that create default layouts for various ranks. + static Layout GetDefaultLayoutForRank(int64 rank); static Layout GetDefaultLayoutForR2(); static Layout GetDefaultLayoutForR3(); static Layout GetDefaultLayoutForR4(); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index e8cee732d4c..4063cb05a91 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -334,6 +334,11 @@ class Literal { // WithLayout use the default XLA layout for the literal's linear // representation in memory. template + static std::unique_ptr CreateFromArray(const Array& values); + template + static std::unique_ptr CreateFromArrayWithLayout( + const Array& values, const Layout& layout); + template static std::unique_ptr CreateR2FromArray2D( const Array2D& values); template @@ -481,6 +486,11 @@ class Literal { std::initializer_list> values, const Layout& layout); template + void PopulateFromArray(const Array& values); + template + void PopulateFromArrayWithLayout(const Array& values, + const Layout& layout); + template void PopulateR2FromArray2D(const Array2D& values); template void PopulateR2FromArray2DWithLayout(const Array2D& values, @@ -815,34 +825,43 @@ template return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); } +template +/* static */ std::unique_ptr Literal::CreateFromArrayWithLayout( + const Array& values, const Layout& layout) { + auto literal = MakeUnique(); + literal->PopulateFromArrayWithLayout(values, layout); + return literal; +} + +template +/* static */ std::unique_ptr Literal::CreateFromArray( + const Array& values) { + return CreateFromArrayWithLayout( + values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); +} + template /* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateR2FromArray2DWithLayout(values, layout); - return literal; + return CreateFromArrayWithLayout(values, layout); } template /* static */ std::unique_ptr Literal::CreateR2FromArray2D( const Array2D& values) { - return CreateR2FromArray2DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR2()); + return CreateFromArray(values); } template /* static */ std::unique_ptr Literal::CreateR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateR3FromArray3DWithLayout(values, layout); - return literal; + return CreateFromArrayWithLayout(values, layout); } template /* static */ std::unique_ptr Literal::CreateR3FromArray3D( const Array3D& values) { - return CreateR3FromArray3DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR3()); + return CreateFromArray(values); } template @@ -901,16 +920,13 @@ template template /* static */ std::unique_ptr Literal::CreateR4FromArray4D( const Array4D& values) { - return CreateR4FromArray4DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR4()); + return CreateFromArray(values); } template /* static */ std::unique_ptr Literal::CreateR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateR4FromArray4DWithLayout(values, layout); - return literal; + return CreateFromArrayWithLayout(values, layout); } template @@ -1069,83 +1085,54 @@ void Literal::PopulateR2( PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } +template +void Literal::PopulateFromArrayWithLayout(const Array& values, + const Layout& layout) { + *mutable_shape() = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), values.dimensions(), + AsInt64Slice(layout.minor_to_major())); + Reserve(values.num_elements()); + values.Each([this](tensorflow::gtl::ArraySlice indices, + NativeT value) { this->Set(indices, value); }); +} + +template +void Literal::PopulateFromArray(const Array& values) { + PopulateFromArrayWithLayout( + values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); +} + template void Literal::PopulateR2FromArray2DWithLayout(const Array2D& values, const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major())); - - const int64 dim1_size = values.width(); - const int64 dim0_size = values.height(); - CHECK_EQ(dim0_size, shape().dimensions(0)); - CHECK_EQ(dim1_size, shape().dimensions(1)); - Reserve(dim1_size * dim0_size); - for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { - for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { - Set({dim0, dim1}, values(dim0, dim1)); - } - } + PopulateFromArrayWithLayout(values, layout); } template void Literal::PopulateR2FromArray2D(const Array2D& values) { - PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); + PopulateFromArray(values); } template void Literal::PopulateR3FromArray3DWithLayout(const Array3D& values, const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {values.n1(), values.n2(), values.n3()}, - AsInt64Slice(layout.minor_to_major())); - - CHECK_EQ(values.n1(), shape().dimensions(0)); - CHECK_EQ(values.n2(), shape().dimensions(1)); - CHECK_EQ(values.n3(), shape().dimensions(2)); - Reserve(values.n1() * values.n2() * values.n3()); - for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { - for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { - for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { - Set({dim0, dim1, dim2}, values(dim0, dim1, dim2)); - } - } - } + PopulateFromArrayWithLayout(values, layout); } template void Literal::PopulateR3FromArray3D(const Array3D& values) { - PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); + PopulateFromArray(values); } template void Literal::PopulateR4FromArray4DWithLayout(const Array4D& values, const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {values.planes(), values.depth(), values.height(), values.width()}, - AsInt64Slice(layout.minor_to_major())); - - CHECK_EQ(values.n1(), shape().dimensions(0)); - CHECK_EQ(values.n2(), shape().dimensions(1)); - CHECK_EQ(values.n3(), shape().dimensions(2)); - CHECK_EQ(values.n4(), shape().dimensions(3)); - Reserve(values.n1() * values.n2() * values.n3() * values.n4()); - for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { - for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { - for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { - for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) { - Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3)); - } - } - } - } + PopulateFromArrayWithLayout(values, layout); } template void Literal::PopulateR4FromArray4D(const Array4D& values) { - PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); + PopulateFromArray(values); } template From 5bb971864220e0afdb5587680f444d3779a0f2cf Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 20 Oct 2017 12:56:51 -0700 Subject: [PATCH 38/41] TFE: Raises an error when attempting to save multiple ResourceVariable objects with the same shared_name. The only way to get multiple objects is if they're created in different Graphs/IsolateTest contexts. Previously this snuck by because of a list -> dictionary conversion without key checking. Allows the same object to be passed multiple times (so people don't need to de-duplicate their lists). PiperOrigin-RevId: 172921932 --- tensorflow/contrib/eager/python/saver_test.py | 34 +++++++++++++++++++ tensorflow/python/training/saver.py | 9 ++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index c89554e6dd0..1605435d8d7 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -54,6 +54,40 @@ class SaverTest(test.TestCase): saver.restore(ckpt_prefix) self.assertEqual(v1.read_value().numpy(), 1.0) + def testSameNameNoClobbering(self): + with context.eager_mode(), ops.device(self._dev()): + # Note that this test purposefully uses Graphs rather than + # IsolateTest. Users are more likely to accidentally create the same + # variable name this way. + first_graph = ops.Graph() + with first_graph.as_default(): + v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1') + with ops.Graph().as_default(): + v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1') + saver = _saver.Saver([v1_first_graph, v1_second_graph]) + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + with self.assertRaisesRegexp(ValueError, 'v1'): + saver.save(ckpt_prefix) + + def testDifferentGraphError(self): + with context.eager_mode(), ops.device(self._dev()): + with ops.Graph().as_default(): + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + with ops.Graph().as_default(): + saver = _saver.Saver([v1]) + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + with self.assertRaisesRegexp(ValueError, 'Graph'): + saver.save(ckpt_prefix) + + def testSameObjectOK(self): + with context.eager_mode(), ops.device(self._dev()): + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + # While different objects with the same shared_name are not good, passing + # in the same object multiple times is fine. + saver = _saver.Saver([v1, v1]) + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + saver.save(ckpt_prefix) + def testRestoreOnCreate(self): with ops.device(self._dev()): def model(init_val): diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index b1926f4eaf6..c4c1df22eb5 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -557,7 +557,14 @@ class BaseSaverBuilder(object): if not isinstance(var, resource_variable_ops.ResourceVariable): raise ValueError("Can only save/restore ResourceVariable eager " "mode is enabled, type: %s." % type(var)) - names_to_saveables[var._shared_name] = var + set_var = names_to_saveables.setdefault(var._shared_name, var) + if set_var is not var: + raise ValueError( + ("Two different ResourceVariable objects with the same " + "shared_name '%s' were passed to the Saver. This likely means " + "that they were created in different Graphs or isolation " + "contexts, and may not be checkpointed together.") % ( + var._shared_name,)) # pylint: enable=protected-access return names_to_saveables From 54503483ef987c6488d7bc2bd3c4b1d34fbd1f26 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Fri, 20 Oct 2017 13:01:41 -0700 Subject: [PATCH 39/41] Enables silent copies of eager tensors for specially-constructed contexts. PiperOrigin-RevId: 172922467 --- tensorflow/c/eager/BUILD | 3 +- tensorflow/c/eager/c_api.cc | 64 +++++++++++++++++++++++++---- tensorflow/c/eager/c_api.h | 16 ++++++++ tensorflow/c/eager/c_api_internal.h | 3 ++ tensorflow/c/eager/c_api_test.cc | 46 +++++++++++++++++++++ 5 files changed, 122 insertions(+), 10 deletions(-) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 96f3c3e195e..c77896b80b4 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "tf_cuda_cc_test", "tf_cc_test", "tf_copts", "tf_cuda_library", @@ -50,7 +51,7 @@ tf_cuda_library( ], ) -tf_cc_test( +tf_cuda_cc_test( name = "c_api_test", srcs = ["c_api_test.cc"], deps = [ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 334c02bff9a..28ea2edee4f 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -61,6 +61,11 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, TF_SetConfig(&options->session_options, proto, proto_len, status); } +void TFE_ContextOptionsSetDevicePlacementPolicy( + TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { + options->policy = policy; +} + void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { @@ -80,6 +85,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { } TFE_Context* ret = new TFE_Context(session); + ret->policy = opts->policy; ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime( ret->session->device_mgr, opts->session_options.options.env, TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {})); @@ -417,8 +423,10 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, namespace { tensorflow::Status ValidateInputTypeAndPlacement( - tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op, - const tensorflow::OpKernel* kernel) { + TFE_Context* ctx, tensorflow::Device* host_device, + tensorflow::Device* op_device, TFE_Op* op, + const tensorflow::OpKernel* kernel, + std::vector* copied_tensors) { const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types(); if (memtypes.size() != op->inputs.size()) { return tensorflow::errors::InvalidArgument( @@ -430,11 +438,42 @@ tensorflow::Status ValidateInputTypeAndPlacement( const tensorflow::Device* actual_device = op->input_devices[i] == nullptr ? host_device : op->input_devices[i]; if (expected_device != actual_device) { - return tensorflow::errors::InvalidArgument( - "cannot compute ", op->name, " as input #", i, - " was expected to be on ", expected_device->name(), - " but is actually on ", actual_device->name(), - " (operation running on ", op_device->name(), ")"); + switch (ctx->policy) { + case TFE_DEVICE_PLACEMENT_EXPLICIT: + return tensorflow::errors::InvalidArgument( + "cannot compute ", op->name, " as input #", i, + " was expected to be on ", expected_device->name(), + " but is actually on ", actual_device->name(), + " (operation running on ", op_device->name(), ")"); + case TFE_DEVICE_PLACEMENT_WARN: + LOG(WARNING) << "before computing " << op->name << " input #" << i + << " was expected to be on " << expected_device->name() + << " but is actually on " << actual_device->name() + << " (operation running on " << op_device->name() + << "). This triggers a copy which can be a performance " + "bottleneck."; + break; + case TFE_DEVICE_PLACEMENT_SILENT: // Do nothing. + break; + } + // We are only here if the policy is warn or silent copies, so we should + // trigger a copy. + TFE_TensorHandle original{op->inputs[i], op->input_devices[i]}; + TF_Status* s = TF_NewStatus(); + TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice( + &original, ctx, expected_device->name().c_str(), s); + if (!s->status.ok()) { + tensorflow::Status status = s->status; + delete s; + return tensorflow::errors::Internal( + "Failed copying input tensor from ", actual_device->name(), " to ", + expected_device->name(), " in order to run ", op->name, ": ", + status.error_message()); + } + op->inputs[i] = copied_tensor->t; + copied_tensors->push_back(copied_tensor); + op->input_devices[i] = copied_tensor->d; + delete s; } if (op->inputs[i].dtype() != kernel->input_type(i)) { return tensorflow::errors::InvalidArgument( @@ -477,10 +516,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); } - status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op, - kernel->kernel()); + std::vector copied_tensors; + status->status = ValidateInputTypeAndPlacement( + ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors); output_memory_types = &kernel->kernel()->output_memory_types(); if (!status->status.ok()) { + for (auto* t : copied_tensors) { + TFE_DeleteTensorHandle(t); + } return; } // WARNING: kernel->Run utilizes the FunctionLibraryRuntime @@ -492,6 +535,9 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // sense for FunctionLibraryRuntime to ensure thread-safe access to // FunctionLibraryDefinition?). status->status = kernel->Run(&op->inputs, &outputs); + for (auto* t : copied_tensors) { + TFE_DeleteTensorHandle(t); + } if (!status->status.ok()) return; *num_retvals = std::min(*num_retvals, outputs.size()); for (int i = 0; i < *num_retvals; ++i) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 201cb222c92..865580c5f3a 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -56,6 +56,22 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig( TFE_ContextOptions* options, const void* proto, size_t proto_len, TF_Status* status); +// Controls how to act when we try to run an operation on a given device but +// some input tensors are not on that device. +typedef enum TFE_ContextDevicePlacementPolicy { + // The default: running operations with input tensors on the wrong device will + // fail. + TFE_DEVICE_PLACEMENT_EXPLICIT = 0, + // Copy the tensor to the right device but log a warning. + TFE_DEVICE_PLACEMENT_WARN = 1, + // Silently copy the tensor, which has a performance cost since the + // operation will be blocked till the copy completes. + TFE_DEVICE_PLACEMENT_SILENT = 2, +} TFE_ContextDevicePlacementPolicy; + +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( + TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); + // Destroy an options object. TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 7a440a5a7e8..0971e2ab2fe 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -37,11 +37,14 @@ limitations under the License. struct TFE_ContextOptions { TF_SessionOptions session_options; + TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT}; }; struct TFE_Context { explicit TFE_Context(TF_Session* s) : session(s) {} + TFE_ContextDevicePlacementPolicy policy; + // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. TF_Session* session; tensorflow::Rendezvous* rendezvous; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 5344956ee77..4af91b8853d 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -216,6 +216,52 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } +TEST(CAPI, TensorHandleSilentCopy) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + const int num_devices = TF_DeviceListCount(devices); + + // Disable the test if no GPU is present. + if (num_devices > 1) { + const int device_to_use = 1; + const string name(TF_DeviceListName(devices, device_to_use, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_TensorHandle* hgpu = + TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); + TFE_OpSetDevice(matmul, name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(hgpu); + } + + TF_DeleteDeviceList(devices); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); +} + TEST(CAPI, Execute) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); From e65fbbc9dc608d97977b17e05250b015d65aa027 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 13:03:50 -0700 Subject: [PATCH 40/41] Expose tf.contrib.framework.current_arg_scope() PiperOrigin-RevId: 172922818 --- tensorflow/contrib/framework/__init__.py | 1 + tensorflow/contrib/framework/python/ops/arg_scope.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 2081a11f47d..8421ba7c042 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -37,6 +37,7 @@ See the @{$python/contrib.framework} guide. @@arg_scope @@add_arg_scope +@@current_arg_scope @@has_arg_scope @@arg_scoped_arguments diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 9c194ec202a..2bce00fde24 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -67,6 +67,7 @@ from tensorflow.python.util import tf_decorator __all__ = ['arg_scope', 'add_arg_scope', + 'current_arg_scope', 'has_arg_scope', 'arg_scoped_arguments'] @@ -83,7 +84,7 @@ def _get_arg_stack(): return _ARGSTACK -def _current_arg_scope(): +def current_arg_scope(): stack = _get_arg_stack() return stack[-1] @@ -144,7 +145,7 @@ def arg_scope(list_ops_or_scope, **kwargs): raise TypeError('list_ops_or_scope must either be a list/tuple or reused' 'scope (i.e. dict)') try: - current_scope = _current_arg_scope().copy() + current_scope = current_arg_scope().copy() for op in list_ops_or_scope: key_op = _key_op(op) if not has_arg_scope(op): @@ -172,7 +173,7 @@ def add_arg_scope(func): A tuple with the decorated function func_with_args(). """ def func_with_args(*args, **kwargs): - current_scope = _current_arg_scope() + current_scope = current_arg_scope() current_args = kwargs key_func = _key_op(func) if key_func in current_scope: From d2d9a6c7cc3b4f8c068054082a0fa2f2b95bb3d6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 20 Oct 2017 13:19:42 -0700 Subject: [PATCH 41/41] Add AdaptiveSharedBatchScheduler which processes batches at a variable rate which can be adjusted based on external feedback. For reasonable feedback, this scheduler should deliver better latency than the SharedBatchScheduler. PiperOrigin-RevId: 172924803 --- tensorflow/contrib/batching/BUILD | 22 + .../adaptive_shared_batch_scheduler.h | 463 ++++++++++++++++++ .../adaptive_shared_batch_scheduler_test.cc | 438 +++++++++++++++++ tensorflow/contrib/batching/batch_scheduler.h | 2 +- 4 files changed, 924 insertions(+), 1 deletion(-) create mode 100644 tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h create mode 100644 tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index 1555a3427fd..ae3f48f1b27 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -69,6 +69,28 @@ tf_cc_test( ], ) +cc_library( + name = "adaptive_shared_batch_scheduler", + hdrs = ["adaptive_shared_batch_scheduler.h"], + deps = [ + ":batch_scheduler", + "//tensorflow/contrib/batching/util:periodic_function_dynamic", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "adaptive_shared_batch_scheduler_test", + srcs = ["adaptive_shared_batch_scheduler_test.cc"], + deps = [ + ":adaptive_shared_batch_scheduler", + "//tensorflow/contrib/batching/test_util:fake_clock_env", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h new file mode 100644 index 00000000000..ac32f096395 --- /dev/null +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -0,0 +1,463 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/batching/batch_scheduler.h" +#include "tensorflow/contrib/batching/util/periodic_function.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace serving { +namespace internal { +template +class ASBSBatch; + +template +class ASBSQueue; +} // namespace internal + +// Shared batch scheduler designed to minimize latency. The scheduler keeps +// track of a number of queues (one per model or model version) which are +// continuously enqueuing requests. The scheduler groups the requests into +// batches which it periodically sends off for processing (see +// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler +// prioritizes batches by age (i.e. the batch's oldest request) irrespective of +// queue. The scheduler will process the oldest batch at an adjustable rate, +// regardless of batch size. The user can provide feedback to help set this rate +// to achieve some goal (i.e. minimize overall latency, limit cpu usage, etc). +// +// The rate (or rather, the corresponding period) is adjusted each time a batch +// is processed, using an exponentially weighted moving average to smooth +// potentially noisy feedback: +// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N +// period *= (1 + K * emwa_feedback) +// +// Some potential use cases: +// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing +// involves serial processing by a device, from a latency perspective it is +// desirable to keep the device evenly loaded, avoiding the need to wait for +// the device to process prior batches. +// feedback = num_pending_on_device() - desired_pending. +// CPU utilization - If the batch processing is cpu dominated, you can reap +// latency gains when underutilized by increasing the processing rate, but +// back the rate off when the load increases to avoid overload. +// feedback = cpu_rate() - desired_cpu_rate. + +template +class AdaptiveSharedBatchScheduler + : public std::enable_shared_from_this< + AdaptiveSharedBatchScheduler> { + public: + struct Options { + // The name to use for the pool of batch threads. + string thread_pool_name = {"batch_threads"}; + // Number of batch processing threads; equivalently the maximum number of + // concurrently running batches. + int64 num_batch_threads = port::NumSchedulableCPUs(); + // The environment to use (typically only overridden by test code). + Env* env = Env::Default(); + // Initial batch scheduling period in microseconds. Will be altered for + // non-zero rate_feedback. + double initial_scheduling_period_micros = 500; + // Minimum batch scheduling period in microseconds. Recommend setting this + // value greater than 0, otherwise it may take a while to recover from a + // sustained time of negative scheduling_period_feedback (which may occur + // under low load). + double min_scheduling_period_micros = 100; + // Maximum batch scheduling period in microseconds. + double max_scheduling_period_micros = 10000; + // Feedback function used to modify the scheduling period each time a batch + // is scheduled. Should return values roughly O(1), with positive values + // resulting in an increased period. + std::function scheduling_period_feedback = [] { return 0.; }; + // To handle potentially noisy scheduling_period_feedback, the period is + // adjusted using an exponentially weighted moving average over the previous + // feedback_smoothing_batches batches. Must be greater than 0. + int64 feedback_smoothing_batches = 10; + }; + + // Ownership is shared between the caller of Create() and any queues created + // via AddQueue(). + static Status Create( + const Options& options, + std::shared_ptr>* scheduler); + + struct QueueOptions { + // Maximum size of each batch. + int max_batch_size = 1000; + // Maximum number of enqueued (i.e. non-scheduled) batches. + int max_enqueued_batches = 10; + }; + + using BatchProcessor = std::function>)>; + + // Adds queue (and its callback) to be managed by this scheduler. + Status AddQueue(const QueueOptions& options, + BatchProcessor process_batch_callback, + std::unique_ptr>* queue); + + private: + // access to AddBatch, RemoveQueue, GetEnv. + friend class internal::ASBSQueue; + + explicit AdaptiveSharedBatchScheduler(const Options& options); + + // Batch scheduling function which runs every scheduling_period_ microseconds. + void ProcessOneBatch(); + + // Notifies scheduler of non-empty batch which is eligible for processing. + void AddBatch(internal::ASBSBatch*); + + // Removes queue from scheduler. + void RemoveQueue(const internal::ASBSQueue* queue); + + Env* GetEnv() const { return options_.env; } + + const Options options_; + + struct BatchCompare { + bool operator()(const internal::ASBSBatch* a, + const internal::ASBSBatch* b); + }; + + // Collection of batches added by AddBatch, ordered by age. Owned by scheduler + // until they are released for processing. + std::priority_queue*, + std::vector*>, BatchCompare> + batches_ GUARDED_BY(mu_); + + // Unowned queues and callbacks added by AddQueue. + std::unordered_map*, BatchProcessor> + queues_and_callbacks_ GUARDED_BY(mu_); + + mutex mu_; + + // Responsible for running ProcessOneBatch. PeriodicFunction was used in order + // to check for deletion so that the thread can be shut down. + std::unique_ptr scheduling_thread_; + + // Responsible for running the batch processing callbacks. + std::unique_ptr batch_thread_pool_; + + // Time interval in microseconds between successive ProcessOneBatch calls. + double scheduling_period_; + + // Exponentially weighted moving average of + // options_.scheduling_period_feedback() evaluated in each ProcessOneBatch + // call. + double ewma_feedback_ = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler); +}; + +////////////////////////////////////////////////////////// +// Implementation details follow. API users need not read. + +namespace internal { +// Consolidates tasks into batches, passing them off to the +// AdaptiveSharedBatchScheduler for processing. +template +class ASBSQueue : public BatchScheduler { + public: + using QueueOptions = + typename AdaptiveSharedBatchScheduler::QueueOptions; + + ASBSQueue(std::shared_ptr> scheduler, + const QueueOptions& options); + + ~ASBSQueue() override; + + // Adds task to current batch. Fails if the task size is larger than the batch + // size or if the current batch is full and this queue's number of outstanding + // batches is at its maximum. + Status Schedule(std::unique_ptr* task) override; + + // Number of tasks waiting to be scheduled. + size_t NumEnqueuedTasks() const override; + + // Number of size 1 tasks which could currently be scheduled without failing. + size_t SchedulingCapacity() const override; + + // Notifies queue that a batch is about to be scheduled; the queue should not + // place any more tasks in this batch. + void ReleaseBatch(const ASBSBatch* batch); + + private: + std::shared_ptr> scheduler_; + const QueueOptions options_; + // Owned by scheduler_. + ASBSBatch* current_batch_ GUARDED_BY(mu_) = nullptr; + int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0; + int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0; + mutable mutex mu_; + TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue); +}; + +// Batch which remembers when and by whom it was created. +template +class ASBSBatch : public Batch { + public: + ASBSBatch(ASBSQueue* queue, int64 creation_time_micros) + : queue_(queue), creation_time_micros_(creation_time_micros) {} + + ~ASBSBatch() override {} + + ASBSQueue* queue() const { return queue_; } + + int64 creation_time_micros() const { return creation_time_micros_; } + + private: + ASBSQueue* queue_; + const int64 creation_time_micros_; + TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch); +}; +} // namespace internal + +// ---------------- AdaptiveSharedBatchScheduler ---------------- + +template +Status AdaptiveSharedBatchScheduler::Create( + const Options& options, + std::shared_ptr>* scheduler) { + if (options.num_batch_threads < 1) { + return errors::InvalidArgument("num_batch_threads must be positive; was ", + options.num_batch_threads); + } + if (options.min_scheduling_period_micros < 0) { + return errors::InvalidArgument( + "min_scheduling_period_micros must be >= 0; was ", + options.min_scheduling_period_micros); + } + if (options.min_scheduling_period_micros > + options.initial_scheduling_period_micros) { + return errors::InvalidArgument( + "initial_scheduling_period_micros (", + options.initial_scheduling_period_micros, + ") must be >= min_scheduling_period_micros (", + options.min_scheduling_period_micros, ")"); + } + if (options.initial_scheduling_period_micros > + options.max_scheduling_period_micros) { + return errors::InvalidArgument( + "initial_scheduling_period_micros (", + options.initial_scheduling_period_micros, + ") must be <= max_scheduling_period_micros (", + options.max_scheduling_period_micros, ")"); + } + if (options.feedback_smoothing_batches < 1) { + return errors::InvalidArgument( + "feedback_smoothing_batches must be positive; was ", + options.feedback_smoothing_batches); + } + scheduler->reset(new AdaptiveSharedBatchScheduler(options)); + return Status::OK(); +} + +template +AdaptiveSharedBatchScheduler::AdaptiveSharedBatchScheduler( + const Options& options) + : options_(options), + scheduling_period_(options.initial_scheduling_period_micros) { + PeriodicFunction::Options opts; + opts.thread_name_prefix = "scheduling_thread"; + opts.env = GetEnv(); + scheduling_thread_.reset( + new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts)); + batch_thread_pool_.reset(new thread::ThreadPool( + GetEnv(), options.thread_pool_name, options.num_batch_threads)); +} + +template +Status AdaptiveSharedBatchScheduler::AddQueue( + const QueueOptions& options, BatchProcessor process_batch_callback, + std::unique_ptr>* queue) { + if (options.max_batch_size <= 0) { + return errors::InvalidArgument("max_batch_size must be positive; was ", + options.max_batch_size); + } + if (options.max_enqueued_batches <= 0) { + return errors::InvalidArgument( + "max_enqueued_batches must be positive; was ", + options.max_enqueued_batches); + } + internal::ASBSQueue* asbs_queue_raw; + queue->reset(asbs_queue_raw = new internal::ASBSQueue( + this->shared_from_this(), options)); + mutex_lock l(mu_); + queues_and_callbacks_[asbs_queue_raw] = process_batch_callback; + return Status::OK(); +} + +template +void AdaptiveSharedBatchScheduler::AddBatch( + internal::ASBSBatch* batch) { + mutex_lock l(mu_); + batches_.push(batch); +} + +template +void AdaptiveSharedBatchScheduler::RemoveQueue( + const internal::ASBSQueue* queue) { + mutex_lock l(mu_); + queues_and_callbacks_.erase(queue); +} + +template +void AdaptiveSharedBatchScheduler::ProcessOneBatch() { + static const double kFeedbackMultiplier = .001; + internal::ASBSBatch* batch = nullptr; + BatchProcessor callback; + const int64 start_time_micros = GetEnv()->NowMicros(); + { + mutex_lock l(mu_); + if (!batches_.empty()) { + batch = batches_.top(); + batches_.pop(); + callback = queues_and_callbacks_[batch->queue()]; + } + } + if (batch != nullptr) { + double feedback = options_.scheduling_period_feedback(); + const int64 N = options_.feedback_smoothing_batches; + ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N; + scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_); + if (scheduling_period_ < options_.min_scheduling_period_micros) { + scheduling_period_ = options_.min_scheduling_period_micros; + } else if (scheduling_period_ > options_.max_scheduling_period_micros) { + scheduling_period_ = options_.max_scheduling_period_micros; + } + // Queue may destroy itself after ReleaseBatch is called. + batch->queue()->ReleaseBatch(batch); + batch_thread_pool_->Schedule([callback, batch] { + callback(std::unique_ptr>(batch)); + }); + } + const int64 sleep_time = + scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros); + if (sleep_time > 0) { + GetEnv()->SleepForMicroseconds(sleep_time); + } +} + +template +bool AdaptiveSharedBatchScheduler::BatchCompare::operator()( + const internal::ASBSBatch* a, + const internal::ASBSBatch* b) { + return a->creation_time_micros() > b->creation_time_micros(); +} + +// ---------------- ASBSQueue ---------------- + +namespace internal { +template +ASBSQueue::ASBSQueue( + std::shared_ptr> scheduler, + const QueueOptions& options) + : scheduler_(scheduler), options_(options) {} + +template +ASBSQueue::~ASBSQueue() { + // Wait until last batch has been scheduled. + const int kSleepMicros = 1000; + for (;;) { + { + mutex_lock l(mu_); + if (num_enqueued_batches_ == 0) { + break; + } + } + scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros); + } + scheduler_->RemoveQueue(this); +} + +template +Status ASBSQueue::Schedule(std::unique_ptr* task) { + bool added_new_batch = false; + size_t size = (*task)->size(); + if (size > options_.max_batch_size) { + return errors::InvalidArgument("Task size ", size, + " is larger than maximum batch size ", + options_.max_batch_size); + } + { + mutex_lock l(mu_); + // Current batch is full, create another if allowed. + if (current_batch_ && + current_batch_->size() + size > options_.max_batch_size) { + if (num_enqueued_batches_ >= options_.max_enqueued_batches) { + return errors::Unavailable("The batch scheduling queue is full"); + } + current_batch_->Close(); + current_batch_ = nullptr; + } + if (!current_batch_) { + added_new_batch = true; + num_enqueued_batches_++; + current_batch_ = + new ASBSBatch(this, scheduler_->GetEnv()->NowMicros()); + } + current_batch_->AddTask(std::move(*task)); + num_enqueued_tasks_++; + } + if (added_new_batch) scheduler_->AddBatch(current_batch_); + return Status::OK(); +} + +template +void ASBSQueue::ReleaseBatch(const ASBSBatch* batch) { + mutex_lock l(mu_); + num_enqueued_batches_--; + num_enqueued_tasks_ -= batch->num_tasks(); + if (batch == current_batch_) { + current_batch_->Close(); + current_batch_ = nullptr; + } +} + +template +size_t ASBSQueue::NumEnqueuedTasks() const { + mutex_lock l(mu_); + return num_enqueued_tasks_; +} + +template +size_t ASBSQueue::SchedulingCapacity() const { + mutex_lock l(mu_); + const int current_batch_capacity = + current_batch_ ? options_.max_batch_size - current_batch_->size() : 0; + const int spare_batches = + options_.max_enqueued_batches - num_enqueued_batches_; + return spare_batches * options_.max_batch_size + current_batch_capacity; +} +} // namespace internal +} // namespace serving +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc new file mode 100644 index 00000000000..a07cd6d834f --- /dev/null +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc @@ -0,0 +1,438 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h" + +#include "tensorflow/contrib/batching/test_util/fake_clock_env.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace serving { +namespace anonymous { + +class FakeTask : public BatchTask { + public: + explicit FakeTask(size_t size) : size_(size) {} + + ~FakeTask() override = default; + + size_t size() const override { return size_; } + + private: + const size_t size_; + + TF_DISALLOW_COPY_AND_ASSIGN(FakeTask); +}; + +// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on +// that task. Returns the resulting status. +Status ScheduleTask(size_t task_size, BatchScheduler* scheduler) { + std::unique_ptr task(new FakeTask(task_size)); + Status status = scheduler->Schedule(&task); + // Schedule() should have consumed 'task' iff it returned Status::OK. + CHECK_EQ(status.ok(), task == nullptr); + return status; +} + +// Creates a thread that waits on 'start' and then advances the fake clock in +// 'env' in a loop until 'stop' is notified. Useful for allowing objects that +// use the clock to be destroyed. +std::unique_ptr CreateFakeClockAdvancerThread( + test_util::FakeClockEnv* env, Notification* start, Notification* stop) { + return std::unique_ptr(Env::Default()->StartThread( + {}, "FakeClockAdvancerThread", [env, start, stop] { + start->WaitForNotification(); + while (!stop->HasBeenNotified()) { + env->AdvanceByMicroseconds(10); + Env::Default()->SleepForMicroseconds(10); + } + })); +} + +TEST(AdaptiveSharedBatchSchedulerTest, Basic) { + for (const bool delete_scheduler_early : {false, true}) { + for (const bool delete_queue_1_early : {false, true}) { + int queue_0_tasks = 0; + auto queue_0_callback = + [&queue_0_tasks](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + for (int i = 0; i < batch->num_tasks(); i++) { + queue_0_tasks += batch->task(i).size(); + } + }; + int queue_1_tasks = 0; + auto queue_1_callback = + [&queue_1_tasks](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + for (int i = 0; i < batch->num_tasks(); i++) { + queue_1_tasks += batch->task(i).size(); + } + }; + { + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create({}, &scheduler)); + + // Create two queues. + std::unique_ptr> queue_0; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_0_callback, &queue_0)); + std::unique_ptr> queue_1; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_1_callback, &queue_1)); + + if (delete_scheduler_early) { + // Delete our copy of the scheduler. The queues should keep it alive + // under the covers. + scheduler = nullptr; + } + // Submit tasks to the two queues, and (optionally) remove the queues. + TF_ASSERT_OK(ScheduleTask(1, queue_0.get())); + TF_ASSERT_OK(ScheduleTask(2, queue_1.get())); + TF_ASSERT_OK(ScheduleTask(3, queue_0.get())); + TF_ASSERT_OK(ScheduleTask(4, queue_1.get())); + if (delete_queue_1_early) { + queue_1 = nullptr; + } + TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); + } + EXPECT_EQ(queue_0_tasks, 9); + EXPECT_EQ(queue_1_tasks, 6); + } + } +} + +TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) { + using Scheduler = AdaptiveSharedBatchScheduler; + std::shared_ptr scheduler; + Scheduler::Options options; + options.num_batch_threads = 0; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.min_scheduling_period_micros = 50; + options.max_scheduling_period_micros = 100; + options.initial_scheduling_period_micros = 1; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.min_scheduling_period_micros = 50; + options.max_scheduling_period_micros = 100; + options.initial_scheduling_period_micros = 1000; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.min_scheduling_period_micros = 100; + options.max_scheduling_period_micros = 50; + options.initial_scheduling_period_micros = 75; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.feedback_smoothing_batches = 0; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); +} + +TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + AdaptiveSharedBatchScheduler::Options options; + options.initial_scheduling_period_micros = 1000; + options.env = &env; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue_0; + std::unique_ptr> queue_1; + int queue_0_tasks = 0; + int queue_1_tasks = 0; + auto queue_0_callback = [&queue_0_tasks, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + for (int i = 0; i < batch->num_tasks(); i++) { + queue_0_tasks += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + auto queue_1_callback = [&queue_1_tasks, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + for (int i = 0; i < batch->num_tasks(); i++) { + queue_1_tasks += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + AdaptiveSharedBatchScheduler::QueueOptions queue_options; + queue_options.max_batch_size = 10; + queue_options.max_enqueued_batches = 0; + // Queue must have max_enqueued_batchs > 1. + EXPECT_FALSE( + scheduler->AddQueue(queue_options, queue_0_callback, &queue_0).ok()); + queue_options.max_enqueued_batches = 2; + TF_ASSERT_OK( + scheduler->AddQueue(queue_options, queue_0_callback, &queue_0)); + queue_options.max_batch_size = 0; + // Queue must have max_batch_size > 0. + EXPECT_FALSE( + scheduler->AddQueue(queue_options, queue_1_callback, &queue_1).ok()); + queue_options.max_batch_size = 2; + queue_options.max_enqueued_batches = 1; + TF_ASSERT_OK( + scheduler->AddQueue(queue_options, queue_1_callback, &queue_1)); + + // Wait for scheduling_thread to sleep. + env.BlockUntilThreadsAsleep(1); + // Task larger than max_batch_size shouldn't schedule. + EXPECT_FALSE(ScheduleTask(15, queue_0.get()).ok()); + TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); + TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); + env.AdvanceByMicroseconds(1); + + // Task larger than max_batch_size shouldn't schedule. + EXPECT_FALSE(ScheduleTask(3, queue_1.get()).ok()); + TF_ASSERT_OK(ScheduleTask(1, queue_1.get())); + TF_ASSERT_OK(ScheduleTask(1, queue_1.get())); + env.AdvanceByMicroseconds(1); + // Exceeds max_enqueued_batches, shouldn't schedule. + EXPECT_FALSE(ScheduleTask(1, queue_1.get()).ok()); + + TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); + // Exceeds max_enqueued_batches, shouldn't schedule. + EXPECT_FALSE(ScheduleTask(6, queue_0.get()).ok()); + TF_ASSERT_OK(ScheduleTask(4, queue_0.get())); + + // Batches should be processed in order from oldest to newest. + env.AdvanceByMicroseconds(1000); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(queue_0_tasks, 10); + EXPECT_EQ(queue_1_tasks, 0); + + env.AdvanceByMicroseconds(1000); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(queue_0_tasks, 10); + EXPECT_EQ(queue_1_tasks, 2); + + env.AdvanceByMicroseconds(1000); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(queue_0_tasks, 19); + EXPECT_EQ(queue_1_tasks, 2); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} + +TEST(AdaptiveSharedBatchSchedulerTest, RateFeedback) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + double feedback = 0; + AdaptiveSharedBatchScheduler::Options options; + options.initial_scheduling_period_micros = 1000; + options.min_scheduling_period_micros = 200; + options.max_scheduling_period_micros = 2000; + options.env = &env; + options.scheduling_period_feedback = [&feedback] { return feedback; }; + options.feedback_smoothing_batches = 1; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + int scheduled_items = 0; + auto queue_callback = [&scheduled_items, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + scheduled_items = 0; + for (int i = 0; i < batch->num_tasks(); i++) { + scheduled_items += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Wait for scheduling_thread to sleep. + env.BlockUntilThreadsAsleep(1); + // Enqueue 6 batches. + for (int i = 0; i < 6; i++) { + TF_ASSERT_OK(ScheduleTask(900 + i, queue.get())); + env.AdvanceByMicroseconds(1); + } + feedback = -500; + env.AdvanceByMicroseconds(994); + env.BlockUntilThreadsAsleep(2); // scheduling period = 500 usec. + EXPECT_EQ(scheduled_items, 900); + env.AdvanceByMicroseconds(500); + env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec. + EXPECT_EQ(scheduled_items, 901); + feedback = 0; + env.AdvanceByMicroseconds(250); + env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec. + EXPECT_EQ(scheduled_items, 902); + feedback = 10000; // large feedback should hit max_scheduling_period. + env.AdvanceByMicroseconds(250); + env.BlockUntilThreadsAsleep(2); // scheduling period = 2000 usec. + EXPECT_EQ(scheduled_items, 903); + feedback = -10000; // large feedback should hit min_scheduling_period. + env.AdvanceByMicroseconds(1999); + // No callback scheduled, only scheduling thread sleeping. + env.BlockUntilThreadsAsleep(1); + EXPECT_EQ(scheduled_items, 903); + env.AdvanceByMicroseconds(1); + env.BlockUntilThreadsAsleep(2); // scheduling period = 200 usec. + EXPECT_EQ(scheduled_items, 904); + env.AdvanceByMicroseconds(200); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(scheduled_items, 905); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} + +TEST(AdaptiveSharedBatchSchedulerTest, FeedbackSmoothing) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + double feedback = 0; + AdaptiveSharedBatchScheduler::Options options; + options.initial_scheduling_period_micros = 1000; + options.env = &env; + options.scheduling_period_feedback = [&feedback] { return feedback; }; + options.feedback_smoothing_batches = 3; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + int scheduled_items = 0; + auto queue_callback = [&scheduled_items, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + scheduled_items = 0; + for (int i = 0; i < batch->num_tasks(); i++) { + scheduled_items += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Wait for scheduling_thread to sleep. + env.BlockUntilThreadsAsleep(1); + // Enqueue 4 batches. + for (int i = 0; i < 4; i++) { + TF_ASSERT_OK(ScheduleTask(900 + i, queue.get())); + env.AdvanceByMicroseconds(1); + } + feedback = -300; + env.AdvanceByMicroseconds(996); + env.BlockUntilThreadsAsleep(2); + // ewma_feedback = 100, scheduling_period = 900. + EXPECT_EQ(scheduled_items, 900); + env.AdvanceByMicroseconds(899); + // No callback scheduled, only scheduling thread sleeping. + env.BlockUntilThreadsAsleep(1); + EXPECT_EQ(scheduled_items, 900); + env.AdvanceByMicroseconds(1); + env.BlockUntilThreadsAsleep(2); + // ewma_feedback = 167, scheduling_period = 750. + EXPECT_EQ(scheduled_items, 901); + env.AdvanceByMicroseconds(749); + // No callback scheduled, only scheduling thread sleeping. + env.BlockUntilThreadsAsleep(1); + EXPECT_EQ(scheduled_items, 901); + feedback = 1000 / 3.; + env.AdvanceByMicroseconds(1); + env.BlockUntilThreadsAsleep(2); + // emwa_feedback = 0, scheduling_period = 750. + EXPECT_EQ(scheduled_items, 902); + env.AdvanceByMicroseconds(749); + // No callback scheduled, only scheduling thread sleeping. + env.BlockUntilThreadsAsleep(1); + EXPECT_EQ(scheduled_items, 902); + env.AdvanceByMicroseconds(1); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(scheduled_items, 903); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} + +TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + AdaptiveSharedBatchScheduler::Options options; + options.initial_scheduling_period_micros = 1000; + options.env = &env; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + int scheduled_items = 0; + auto queue_callback = [&scheduled_items, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + scheduled_items = 0; + for (int i = 0; i < batch->num_tasks(); i++) { + scheduled_items += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + AdaptiveSharedBatchScheduler::QueueOptions queue_options; + queue_options.max_batch_size = 10; + queue_options.max_enqueued_batches = 10; + TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue)); + + // Wait for scheduling_thread to sleep. + env.BlockUntilThreadsAsleep(1); + // Enqueue 3 tasks. + EXPECT_EQ(queue->NumEnqueuedTasks(), 0); + EXPECT_EQ(queue->SchedulingCapacity(), 100); + TF_ASSERT_OK(ScheduleTask(5, queue.get())); + EXPECT_EQ(queue->NumEnqueuedTasks(), 1); + EXPECT_EQ(queue->SchedulingCapacity(), 95); + env.AdvanceByMicroseconds(1); + TF_ASSERT_OK(ScheduleTask(6, queue.get())); + EXPECT_EQ(queue->NumEnqueuedTasks(), 2); + EXPECT_EQ(queue->SchedulingCapacity(), 84); + env.AdvanceByMicroseconds(1); + TF_ASSERT_OK(ScheduleTask(1, queue.get())); + EXPECT_EQ(queue->NumEnqueuedTasks(), 3); + EXPECT_EQ(queue->SchedulingCapacity(), 83); + + env.AdvanceByMicroseconds(998); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(scheduled_items, 5); + env.AdvanceByMicroseconds(1000); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(scheduled_items, 7); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} +} // namespace anonymous +} // namespace serving +} // namespace tensorflow diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h index 7c41ad88180..a5072f439ab 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -78,7 +78,7 @@ template class Batch { public: Batch() = default; - ~Batch(); // Blocks until the batch is closed. + virtual ~Batch(); // Blocks until the batch is closed. // Appends 'task' to the batch. After calling AddTask(), the newly-added task // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1).