From 74770749840e1c823a50b743a50637afc3529e3c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Sat, 29 Apr 2017 12:03:51 -0800 Subject: [PATCH 01/38] [XLA] Make ReshapeMover account for broadcast operands, add VLOGging for debug. Change: 154637127 --- tensorflow/compiler/xla/service/BUILD | 2 + .../compiler/xla/service/hlo_instruction.cc | 12 +- .../compiler/xla/service/hlo_instruction.h | 5 +- .../compiler/xla/service/reshape_mover.cc | 274 +++++++++++------- .../xla/service/reshape_mover_test.cc | 52 ++++ 5 files changed, 228 insertions(+), 117 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 21378887266..05fc480936f 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -859,7 +859,9 @@ cc_library( ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index d8e01f88b9f..179e1832654 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -410,7 +410,9 @@ HloInstruction::CreateSelectAndScatter( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape( const Shape& shape, HloInstruction* operand) { CHECK_EQ(ShapeUtil::ElementsIn(shape), - ShapeUtil::ElementsIn(operand->shape())); + ShapeUtil::ElementsIn(operand->shape())) + << "shape: " << ShapeUtil::HumanString(shape) + << " operand: " << ShapeUtil::HumanString(operand->shape()); auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); instruction->AppendOperand(operand); return instruction; @@ -1428,7 +1430,8 @@ string HloInstruction::ExtendedOpcodeStr() const { return opc_name; } -string HloInstruction::ToString(bool compact_operands) const { +string HloInstruction::ToString(bool compact_operands, + bool include_metadata) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. @@ -1509,8 +1512,9 @@ string HloInstruction::ToString(bool compact_operands) const { if (opcode() == HloOpcode::kGetTupleElement) { StrAppend(&extra, ", index=", tuple_index()); } - if (!metadata_.op_type().empty() || !metadata_.op_name().empty() || - !metadata_.source_file().empty()) { + if (include_metadata && + (!metadata_.op_type().empty() || !metadata_.op_name().empty() || + !metadata_.source_file().empty())) { StrAppend(&extra, " # metadata=", metadata_.ShortDebugString()); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 1c7d4c19b97..5ec17c80048 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -489,7 +489,10 @@ class HloInstruction { string SignatureString() const; // Returns a debugging string that represents this instruction. - string ToString(bool compact_operands = false) const; + string ToString(bool compact_operands = false, + bool include_metadata = true) const; + + string ToStringNoMetadata() const { return ToString(false, false); } // As ToString, but returns a shorter string. string ToShortString() const; diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index b72ef95a6a7..768977ba6bb 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -13,17 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/reshape_mover.h" - -#include <algorithm> -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { - -namespace { - +// Implementation note: +// // The general idea behind this pass is that we're converting from this: // %param.A = OldShape // %param.B = OldShape @@ -44,6 +35,19 @@ namespace { // only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or // transposes to a scalar should be cheap, we simply never move them. +#include "tensorflow/compiler/xla/service/reshape_mover.h" + +#include <algorithm> +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + // Finds the first non-scalar operand of an instruction that is a reshape or // transpose and returns the operand if it is found or nullptr if not found. HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) { @@ -51,6 +55,9 @@ HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) { if (!ShapeUtil::IsScalar(operand->shape()) && (operand->opcode() == HloOpcode::kReshape || operand->opcode() == HloOpcode::kTranspose)) { + VLOG(5) << "Found first non-scalar reshape operand of " + << hlo->ToStringNoMetadata() << ":\n\t" + << operand->ToStringNoMetadata(); return operand; } } @@ -70,6 +77,9 @@ bool OperandCanTrivallyChangeShape(const HloInstruction* instruction, // A constant can trivially reshape the literal it holds. if (operand->opcode() == HloOpcode::kConstant && ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { + VLOG(5) << "Constant had same dimensions as instruction:\n\toperand: " + << operand->ToStringNoMetadata() + << "\n\tinstruction: " << instruction->ToStringNoMetadata(); return true; } @@ -116,119 +126,159 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes( if (!first_reshape_operand) { return false; } - return (instruction->user_count() > 0 || - instruction == instruction->parent()->root_instruction()) && - instruction->IsElementwise() && !operands.empty() && - // Check whether all operands: - // 1. are all reshapes or transposes that have the same input and - // output shapes as all other reshaped or transposed operands. - // or - // 2. can be any shape like kConstant, kRng, and scalars. - std::all_of( - operands.begin(), operands.end(), - [instruction, - first_reshape_operand](const HloInstruction* operand) { - return AreEquivalentReshapes(first_reshape_operand, operand) || - OperandCanTrivallyChangeShape(instruction, operand); - }); + VLOG(3) << "** Checking whether instruction is an elementwise operation of " + "equivalent reshapes/transposes: " + << instruction->ToStringNoMetadata(); + bool result = + (instruction->user_count() > 0 || + instruction == instruction->parent()->root_instruction()) && + instruction->IsElementwise() && !operands.empty() && + // Check whether all operands: + // 0. Have the same dimensions as the output -- if not, it may be + // implicitly broadcast, which can confound the movement's + // correctness. + // 1. Are all reshapes or transposes that have the same input and + // output shapes as all other reshaped or transposed operands. + // or + // 2. Can be any shape like kConstant, kRng, and scalars. + std::all_of( + operands.begin(), operands.end(), + [instruction, first_reshape_operand](const HloInstruction* operand) { + if (!ShapeUtil::SameDimensions(operand->shape(), + instruction->shape())) { + VLOG(5) << "Operand shape differs from output shape; may be " + "implicitly broadcast, so preventing " + "movement\n\toperand: " + << operand->ToStringNoMetadata() << "\n\tinstruction: " + << instruction->ToStringNoMetadata(); + return false; + } + if (AreEquivalentReshapes(first_reshape_operand, operand)) { + VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " + << first_reshape_operand->ToStringNoMetadata() + << "\n\toperand: " << operand->ToStringNoMetadata(); + return true; + } + if (OperandCanTrivallyChangeShape(instruction, operand)) { + VLOG(5) << "Operand can trivially change shape: " + << operand->ToStringNoMetadata(); + return true; + } + return false; + }); + VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for " + << instruction->ToStringNoMetadata() << ": " << result; + return result; } // Try to sink any reshape or transpose operands of `instruction` across it. We // do so if `instruction` is elementwise and all operands are equivalent // reshapes or transposes. -bool TrySinkReshapeOrTranspose(HloComputation* computation, - HloInstruction* instruction) { - if (IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) { - std::vector<HloInstruction*> operands = instruction->operands(); - HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction); - CHECK(old_reshape != nullptr); - Shape new_elementwise_shape = old_reshape->operand(0)->shape(); - for (size_t i = 0; i < operands.size(); ++i) { - // All scalar operands remain as-is, even if they're reshape or transpose, - // to simplify handling wrt special scalar broadcast rules for ops like - // Select. Scalar reshapes should be cheap anyways. - if (ShapeUtil::IsScalar(operands[i]->shape())) { - continue; - } - auto element_type = operands[i]->shape().element_type(); - switch (operands[i]->opcode()) { - case HloOpcode::kConstant: { - if (old_reshape->opcode() == HloOpcode::kReshape) { - operands[i] = instruction->parent()->AddInstruction( - HloInstruction::CreateReshape( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i])); - } else { - CHECK_EQ(old_reshape->opcode(), HloOpcode::kTranspose); - std::vector<int64> inverse_permutation = - InversePermutation(old_reshape->dimensions()); - operands[i] = instruction->parent()->AddInstruction( - HloInstruction::CreateTranspose( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i], inverse_permutation)); - } - break; - } - case HloOpcode::kRng: { - CHECK_EQ(operands[i]->user_count(), 1); +StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation, + HloInstruction* instruction) { + if (!IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) { + return false; + } + + std::vector<HloInstruction*> operands = instruction->operands(); + HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction); + TF_RET_CHECK(old_reshape != nullptr); + Shape new_elementwise_shape = old_reshape->operand(0)->shape(); + + VLOG(3) << "** Trying to sink reshape or transpose: " + << instruction->ToStringNoMetadata() + << "\n\told reshape: " << old_reshape->ToStringNoMetadata() + << "\n\tnew elementwise shape: " + << ShapeUtil::HumanString(new_elementwise_shape); + for (size_t i = 0; i < operands.size(); ++i) { + // All scalar operands remain as-is, even if they're reshape or transpose, + // to simplify handling wrt special scalar broadcast rules for ops like + // Select. Scalar reshapes should be cheap anyways. + if (ShapeUtil::IsScalar(operands[i]->shape())) { + continue; + } + PrimitiveType element_type = operands[i]->shape().element_type(); + switch (operands[i]->opcode()) { + case HloOpcode::kConstant: { + if (old_reshape->opcode() == HloOpcode::kReshape) { + VLOG(3) << "Creating reshape for kConstant operand " << i << ": " + << operands[i]->ToStringNoMetadata(); operands[i] = instruction->parent()->AddInstruction( - operands[i]->CloneWithNewOperands( + HloInstruction::CreateReshape( ShapeUtil::ChangeElementType(new_elementwise_shape, element_type), - operands[i]->operands())); - break; + operands[i])); + } else { + TF_RET_CHECK(old_reshape->opcode() == HloOpcode::kTranspose); + std::vector<int64> inverse_permutation = + InversePermutation(old_reshape->dimensions()); + operands[i] = instruction->parent()->AddInstruction( + HloInstruction::CreateTranspose( + ShapeUtil::ChangeElementType(new_elementwise_shape, + element_type), + operands[i], inverse_permutation)); } - case HloOpcode::kReshape: - case HloOpcode::kTranspose: - operands[i] = operands[i]->mutable_operand(0); - break; - default: - LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or " - "transposes."; - } - } - if (HloOpcode::kFusion == instruction->opcode()) { - // Here we already know `instruction` is elementwise, and no operand is - // implicit broadcast as if it were the operands would not be equivalent - // reshapes, so all the fused instructions have the same dimensions. - for (const auto& fused_instruction : instruction->fused_instructions()) { - Shape* shape = fused_instruction->mutable_shape(); - *shape->mutable_dimensions() = new_elementwise_shape.dimensions(); - *shape->mutable_layout() = new_elementwise_shape.layout(); - } - } - auto new_elementwise = - computation->AddInstruction(instruction->CloneWithNewOperands( - // `instruction` may change the element type, e.g., from - // operands[0] -> reshape -> convert (`instruction`) - // to - // operands[0] -> convert' -> reshape' - // - // In this case, convert' should have the same element type as - // `convert` and the same dimensions as operands[0]. - ShapeUtil::ChangeElementType(new_elementwise_shape, - instruction->shape().element_type()), - operands)); - std::unique_ptr<HloInstruction> new_reshape; - switch (old_reshape->opcode()) { - case HloOpcode::kReshape: - new_reshape = HloInstruction::CreateReshape(instruction->shape(), - new_elementwise); break; + } + case HloOpcode::kRng: { + CHECK_EQ(operands[i]->user_count(), 1); + operands[i] = instruction->parent()->AddInstruction( + operands[i]->CloneWithNewOperands( + ShapeUtil::ChangeElementType(new_elementwise_shape, + element_type), + operands[i]->operands())); + break; + } + case HloOpcode::kReshape: case HloOpcode::kTranspose: - new_reshape = HloInstruction::CreateTranspose( - instruction->shape(), new_elementwise, old_reshape->dimensions()); + operands[i] = operands[i]->mutable_operand(0); break; default: - LOG(FATAL) << "Bad opcode"; + LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or " + "transposes."; } - TF_CHECK_OK(computation->ReplaceWithNewInstruction(instruction, - std::move(new_reshape))); - return true; } - return false; + if (HloOpcode::kFusion == instruction->opcode()) { + // Here we already know `instruction` is elementwise, and no operand is + // implicit broadcast as if it were the operands would not be equivalent + // reshapes, so all the fused instructions have the same dimensions. + for (const auto& fused_instruction : instruction->fused_instructions()) { + Shape* shape = fused_instruction->mutable_shape(); + *shape->mutable_dimensions() = new_elementwise_shape.dimensions(); + *shape->mutable_layout() = new_elementwise_shape.layout(); + } + } + HloInstruction* new_elementwise = + computation->AddInstruction(instruction->CloneWithNewOperands( + // `instruction` may change the element type, e.g., from + // operands[0] -> reshape -> convert (`instruction`) + // to + // operands[0] -> convert' -> reshape' + // + // In this case, convert' should have the same element type as + // `convert` and the same dimensions as operands[0]. + ShapeUtil::ChangeElementType(new_elementwise_shape, + instruction->shape().element_type()), + operands)); + + std::unique_ptr<HloInstruction> new_reshape; + switch (old_reshape->opcode()) { + case HloOpcode::kReshape: + VLOG(3) << "Creating new reshape for new elementwise op: " + << new_elementwise->ToStringNoMetadata(); + new_reshape = + HloInstruction::CreateReshape(instruction->shape(), new_elementwise); + break; + case HloOpcode::kTranspose: + new_reshape = HloInstruction::CreateTranspose( + instruction->shape(), new_elementwise, old_reshape->dimensions()); + break; + default: + LOG(FATAL) << "Bad opcode"; + } + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + instruction, std::move(new_reshape))); + return true; } } // namespace @@ -237,9 +287,9 @@ StatusOr<bool> ReshapeMover::Run(HloModule* module) { bool changed = false; for (const auto& comp : module->computations()) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { - if (TrySinkReshapeOrTranspose(comp.get(), instruction)) { - changed = true; - } + TF_ASSIGN_OR_RETURN(bool did_change, + TrySinkReshapeOrTranspose(comp.get(), instruction)); + changed |= did_change; } } return changed; diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 1831d775d4a..5217e85d4fc 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -234,6 +234,58 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { EXPECT_EQ(select, computation->root_instruction()); } +// Tree looks like: +// +// param0 [1,128,1] +// | +// reshape [128,1] constant [128,1024] +// \ / +// multiply w/implicit broadcast [128,1024] +// +// The reshape mover would like to sink the reshape below the multiply. +// +// Previously we would attempt to insert a reshape of the constant to [1,128,1] +// (which is unsound, because it has a different number of elements) as +// preparation for sinking the reshape. +// +// To eliminate the unsoundness, we outlaw reshape sinking when one of the +// operands is implicitly broadcast in the elementwise consumer. +// +// TODO(b/37799338) However, it would be possible in this case to do a more +// in-depth analysis to get reshape movement to occur: +// +// 1. Note that the broadcast dimension (logical dimension 1) in the operands +// would map back to logical dimension 2 in the param0 node. +// 2. Match rank of the constant to the param0 node (by prepending a trivial 1 +// dimension). +// 3. Reshape to [128,1024] at the root. +// +// But this is not currently done. +TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) { + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0")); + auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {128, 1}), param0)); + Array2D<float> a(128, 1024); + auto literal = LiteralUtil::CreateR2FromArray2D<float>(a); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + auto multiply = builder.AddInstruction(HloInstruction::CreateBinary( + constant->shape(), HloOpcode::kMultiply, constant, reshape)); + + auto module = MakeUnique<HloModule>(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Constant(), op::Reshape(param0))); + + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Constant(), op::Reshape(param0))); + EXPECT_EQ(multiply, computation->root_instruction()); +} + // Tree looks like this: // // add1 From dac34e886862ab92c7ee37e1abf8970608abde74 Mon Sep 17 00:00:00 2001 From: Toby Boyd <tobyboyd@google.com> Date: Sun, 30 Apr 2017 15:36:45 -0800 Subject: [PATCH 02/38] Details about performance models and techniques Change: 154679628 --- .../performance/performance_models.md | 389 ++++++++++++++++++ 1 file changed, 389 insertions(+) create mode 100644 tensorflow/docs_src/performance/performance_models.md diff --git a/tensorflow/docs_src/performance/performance_models.md b/tensorflow/docs_src/performance/performance_models.md new file mode 100644 index 00000000000..18f407c702e --- /dev/null +++ b/tensorflow/docs_src/performance/performance_models.md @@ -0,0 +1,389 @@ +# TensorFlow High-Performance Models + +TensorFlow is a powerful and flexible machine learning platform. +It can be used to distribute model training and inference across a large number +of machines and computation devices. + +Its software stack is made of a few layers: + +* a fast and powerful C++ core +* low-level Python primitives that sit right above individual kernels +* a diverse range of high-level libraries that aim to make building real models + easier + +There are many existing examples and tutorials that explain useful features in +TensorFlow. The goal of this set of scripts is to demonstrate that we can build +flexible and powerful high-performance models using the low-level APIs. +In the future, many of the high-performance primitives will be incorporated into +high-level APIs, and made available to more users transparently. +But meanwhile, we show that it is fairly easy for advanced users to build highly +scalable models targeting different system types, network topologies, etc. + +We divide our effort to build high-performance models into three categories: + +1. A fast input pipeline to read data from disk, preprocess it, and make it + ready on the GPU. +2. A high-throughput model that trains on GPU very efficiently. +3. Fast variable and gradients distribution mechanisms that scale well across + many machines and computation devices. + +## Input Pipeline + +The input pipeline is the part of a tensorflow program that reads input data, +shuffles it, and preprocesses it. + +Among the most important features to build a fast input pipeline: + +* Avoid using feed-dictionary to feed a large amount of data for each step. + * Instead, use reader ops to get data into TensorFlow directly. +* Parallelize data processing. +* Use software pipelining to feed data, so that data is available immediately + when needed. + +One way to implement software pipelining in TensorFlow is through +`tf.FifoQueue`, and it is possible to parallelize data processing through +`tf.train.queue_runner`, which uses Python threads as its underlying +implementation. +This lays the foundation for the current Inception input pipeline. +This design is well built for feeding older generation of GPUs, +but the overhead of Python threads is too large to feed newer GPUs that are four +to five times faster. + +In this model, we explore an alternative design that uses the native +parallelism in TensorFlow. In our example of an image model input pipeline, +there are a few important parts: + +* Choose and read the image files from the disk. +* Decode the image data into images, transform and add distortion so they are +ready to be used. +* Organize the transformed images into a minibatch. +* Transfer the images from CPU to GPU, so they are ready for model training. + +It is important to note that the dominant part of each stage can happen in +parallel with that of other stages: +the file IO uses DMA to transfer the data from hard disk to memory; +image decoding, transformation and distortion are CPU-heavy; +the data transfer from CPU to GPU uses the GPU's copy-engine unit; +and the GPU kernels use the main SMs of the GPU. +It is natural to cut our pipeline into those parts so they can run in parallel +with each other. + +Also, as mentioned earlier, most of the current input pipeline heavily uses +Python threads. However, the large overhead introduced by Python threads +severely limits its scalability when the newer GPUs are a lot faster; we can +alleviate this by making a single `session.run` call execute all parts of the +pipeline. + +### Parallelize IO Reads + +In this new model, we use the native parallelism in TensorFlow: TensorFlow +subscribes to an eager-execution model, which means that when nodes in the graph +became available, TensorFlow will try to execute as many of them as possible. + +In order to parallelize reading from hard disk, we use `data_flow_ops.RecordInput` +in this model. +Given a list of input files of TFRecords, `RecordInput` continuously reads +records using background threads, placing the records into its own large, +internal pool of records. +When it is has loaded at least half of its capacity, it produces output tensors. + +Since this op has its internal threads, and is dominated by IO time that doesn’t +consume much CPU time, it naturally runs in parallel with the rest of the model. + +### Parallelize Image Processing + +After reading from “RecordInput”, the tensors are passed to the input processing +pipeline. For example, if we need to feed 8 GPUs, each with a batch-size of 32, +then for each step we do the following. + +First, read 32x8=256 records, and process them individually, in +parallel. This starts with 256 independent RecordInput read ops in the graph. + +Then, follow each read with identical set of ops for processing. Each set is +considered independent and will execute in parallel. The operations include +image decoding, image distortion, and resizing. + +Finally, once the images are ready, they will be concatenated together into 8 +batch-size 32 tensors. +Note that we can use “tf.concat” for this purpose. +However, “tf.concat” is implemented as a single op, which waits for all +the inputs to be ready, and then concatenates them together. Since all +inputs are produced in parallel, there will be a long tail waiting for all +inputs to be available; and when concatenation happens, the op becomes memory +limited as all input tensors compete for memory bandwidth. +So for the final concatenation, we use `tf.parallel_stack` instead. This +allocates an uninitialized tensor as an output, and each input tensor is written +to its designated portion of the output tensor as soon as the input is +available. When all the input tensors are finished, the output tensor is passed +along in the graph. This effectively hides all the memory latency with the long +tail of producing all the input tensors. + +### Parallelize CPU-to-GPU Data Transfer + +In our example, once all the input images are processed and concatenated +together by the CPU, we have 8 tensors, each of which has a batch-size of 32. +These tensors are then to be used by the GPU for the model training. + +In TensorFlow, users can use tensors from one device on any other device +directly. TensorFlow inserts implicit copies to make the tensors available on +any devices where they are used. The runtime schedules the copy between devices +to run before the tensors are actually used. However, if the copy cannot finish +in time, the computation that needs those tensors will stall. + +For high-performance models, it is helpful to explicitly schedule the copy ahead +of the time in parallel, so when the computation starts on GPU, all the tensors +are already available on the right device. + +### Software Pipelining + +With all the stages capable of being driven by different processors, we insert +`data_flow_ops.StagingArea` in between them so they run in parallel. +`StagingArea` is a queue-like operator similar to `tf.FifoQueue`. +But it offers simpler functionalities and can be executed on both CPU and GPU. + +Before the model starts running all the stages, we warm up the stages in order +so the staging buffers in between all have one set of data in them. +During each run step that follows, we will run all the stages. +They read one set of data from the staging buffers at the beginning of each +stage, and push one set at end end. + +For example: if there are three stages: A, B and C. +There are two staging areas in between: S1 and S2. +During the warmup, we run: + +``` +Warm up: +Step 1: A0 +Step 2: A1 B0 + +Actual execution: +Step 3: A2 B1 C0 +Step 4: A3 B2 C1 +Step 5: A4 B3 C2 +``` + +After the warmup, S1 and S2 each have one set of data in them. +For each step of the actual execution, one set of data is consumed from each +staging area, and one set is added to each. + +There are a few nice properties about the scheme: + +* All the stages are non-blocking, since the staging areas always have one set +of data after the warmup. +* Each stage can run in parallel since they can all start immediately. +* The staging buffers have a fixed memory overhead. They will have at most one + extra set of data. +* Only a single`session.run()` call is needed to run all stages of the step, + which makes profiling and debugging much easier. + +## Best Practices in Building High-Performance Models + +The computation on GPU can happen immediately since the input data have already +been transferred onto GPU when the step starts. +But it is still important to build the model that runs as fast as possible. +Here are some tips for a high-performance convolutional neural network (CNN) +model: + +### Build the model with both NHWC and NCHW + +Most TensorFlow operations used by a CNN support both NHWC and NCHW data format. +On GPU, NCHW is faster. +But on CPU, NHWC is sometimes faster. + +So it is a good idea to build the model that can work in both ways. +Our model shows a good way to do that effectively. +For GPU training, we should always use NCHW. +But if the model needs inference on CPU, we could use NHWC; weights obtained +from training with NCHW data format can be used for inference in NHWC data +format. + +### Use Fused Batch-Normalization + +The default batch-normalization in TensorFlow is implemented as composite +operations. +This is very general, but often leads to suboptimal performance. +An alternative is the fused batch-normalization, and the performance on GPU +is often much faster. + +## Variable Distribution and Gradient Aggregation + +During training, training variable values are updated using aggregated gradients +and deltas. In this model, we demonstrate that with the flexible and +general-purpose TensorFlow primitives, it is fairly easy to build a diverse +range of high-performance distribution and aggregation schemes for different +types of systems. + +For example: + +* The standard parameter-server where each replica of the training model reads + the variables directly, and updates the variable independently. When each + model needs the variables, they are copied over through the standard implicit + copies added by the TensorFlow runtime. It is shown how to use this method + in either local training, distributed synchronous training, and distributed + asynchronous training. +* A replicated mode for local training where each GPU has an identical + copy of the training parameters. The forward and backward computation can + start immediately as the variable data is immediately available. Gradients + are accumulated across all GPUs, and the aggregated total is applied to + each GPU's copy of the variables so that they stay in sync. +* A distributed replicated mode of training where each GPU has an identical copy + of the training parameters, and a master copy of the variables is stored + on the parameter-servers. The forward and backward computation can + start immediately as the variable data is immediately available. Gradients + are accumulated across all GPUs on each server and then the per-server + aggregated gradients are applied to the master copy. After all workers do + this, each worker updates its copy of the variable from the master copy. + +We show that most of the variable distribution and aggregation subsystem can +be implemented through TensorFlow low-level primitives with manageable +complexity at the model level. Here we discuss some more details. + +### Parameter-server Variables + +The most common way trainable variables are managed in TensorFlow models is the +parameter server mode. + +In a distributed system, this means that each worker process runs the same +model, and parameter server processes own the master copies of the variables. +When a worker needs a variable from a parameter server, it refers to it +directly. The TensorFlow runtime adds implicit copies to the graph to make the +variable value available on the computation device that needs it. When a +gradient is computed on a worker, it is sent to the parameter server that owns +the particular variable, and the corresponding optimizer is used to update the +variable. + +There are some techniques to improve throughput: + +* The variables are spread among parameter servers based on their size, for load + balancing. +* When each worker has multiple GPUs, gradients are accumulated across the GPUs + and a single aggregated gradient is sent to the parameter server. This reduces + the network bandwidth and the amount of work done by the parameter servers. + +For coordinating between workers, a very common mode is async updates, where +each worker updates the master copy of the variables without synchronizing with +other workers. In our model, we demonstrate that it is fairly easy to introduce +synchronization across workers so updates for all workers are finished in one +step before the next step can start. + +The parameter-server method can also be used for local training, In this case, +instead of spreading the master copies of variables across parameters servers, +they are either on the CPU or spread across the available GPUs. + +Due to the simple nature of this setup, this architecture has gained a lot of +popularity within the community. + +This is available in the benchmark scripts as the 'parameter_server' +variable_update mode. + +{ +width="900" style="max-width: inherit"} + +### Replicated Variables + +In this design, each GPU on the server has its own copy of each variable. The +values are kept in sync across GPUs by applying the fully aggregated gradient to +each GPU's copy of the variable. + +The variables and data are available at the start of training, so the forward +pass of training can start immediately. Gradients are aggregated across the +devices and the fully aggregated gradient is then applied to each local copy. + +Gradient aggregation across the server can be done in different ways: + +* Using standard tensorflow operations to accumulate the total on a single + device (CPU or GPU) and then copy it back to all GPUs. +* Using NVIDIA NCCL, described below in the NCCL section. + +This is available in the benchmark scripts for local execution only, as the +'replicated' variable_update mode. + +### Replicated Variables in Distributed Training + +The replicated method for variables can be extended to distributed training. +One way to do this like the replicated mode: aggregate the gradients fully +across the cluster and apply them to each local copy of the variable. This may +be shown in a future version of this scripts; the scripts do present a different +variation, described here. + +In this mode, in addition to each GPU's copy of the variables, a master copy is +stored on the parameter servers. As with the replicated mode, training can start +immediately using the local copies of the variables. + +As the gradients of the weights become available, they are sent back to the +parameter servers and all local copies are updated: + +1. All the gradients from the GPU on the same worker are aggregated together. +2. Aggregated gradients from each worker are sent to the parameter server that + owns the variable, where the specified optimizer is used to update the + master copy of the variable. +3. Each worker updates its local copy of the variable from the master. In + the example model, this is done with a cross-replica barrier that waits for + all the workers to finish updating the variables, and fetches the new + variable only after the barrier has been released by all replicas. Once the + copy finishes for all variables, this marks the end of a training step, and a + new step can start. + +Although this sounds similar to the standard use of parameter servers, the +performance is often better in many cases. This is largely due to the fact the +computation can happen without any delay, and much of the copy latency of early +gradients can be hidden by later computation layers. + +This is available in the benchmark scripts as the 'distributed_replicated' +variable_update mode. + +{ +width="900" style="max-width: inherit"} + +#### NCCL + +In order to broadcast variables and aggregate gradients across different GPUs +within the same host machine, we can use the default TensorFlow implicit copy +mechanism. + +However, we can instead use the optional NCCL support. NCCL is an NVIDIA +library that can efficiently broadcast and aggregate data across different GPUs. +It schedules a cooperating kernel on each GPU that knows how to best utilize the +underlying hardware topology; this kernel uses a single SM of the GPU. + +In our experiment, we demonstrate that although NCCL often leads to much faster +data aggregation by itself, it doesn't necessarily lead to faster training. Our +hypothesis is that the implicit copies are essentially free since they go to the +copy engine on GPU, as long as its latency can be hidden by the main computation +itself. Although NCCL can transfer data faster, it takes one SM away, and adds +more pressure to the underlying L2 cache. Our results show that for 8-GPUs, +NCCL often leads to better performance. However, for fewer GPUs, the implicit +copies often perform better. + +#### Staged Variables + +We further introduce a staged-variable mode where we use staging areas for both +the variable reads, and their updates. +Similar to software pipelining of the input pipeline, this can hide the data +copy latency. +If the computation time takes longer than the copy and aggregation, the copy +itself becomes essentially free. + +The downside is that all the weights read are from the previous training step. +So it is a different algorithm from SGD. +But it is possible to improve its convergence by adjusting learning rate and +other hyperparameters. + +## Conclusions + +In this high-performance model, we present a number of options to build +high-performance models in TensorFlow. +Due to the flexible design in TensorFlow, advanced features like this often +requires no system-level changes, and can be largely achieved through +model-level changes. + +We do not claim which combination works best for a particular model. +That should be left to the engineers who build the model and the training system. +Many of the ingredients of the high-performance model will find their ways +to high-level primitives that become transparent to users. +However, we have shown that advanced users can easily tune and modify the +underlying model behavior using low-level primitives. +This could be very useful when improving performance for particular system +setups and model configurations. From e370aa6434ce3e17f8fa2a9df690500a8dc03d4e Mon Sep 17 00:00:00 2001 From: Justin Lebar <jlebar@google.com> Date: Sun, 30 Apr 2017 19:48:05 -0800 Subject: [PATCH 03/38] Add ReferenceUtil::ApplyElementwise2D. Change: 154687238 --- tensorflow/compiler/xla/reference_util.h | 39 ++++++++++++++++ .../compiler/xla/reference_util_test.cc | 44 ++++++++++++------- 2 files changed, 67 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index cdcad08c33e..03276121294 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -406,7 +406,46 @@ class ReferenceUtil { const PaddingConfig& padding, const float pad); + // ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running + // f(x[i], y[i], ...) for each array element in the Array2Ds x, y, .... + // + // The given arrays must have the same size and element type, and the return + // type of f must be implicitly convertible to the arrays' element type. + // + // Example usage: + // + // Array2D<float> x, y, z = ...; + // std::unique_ptr<Array2D> result = ReferenceUtil::ApplyElementwise2D( + // [](float a, float b, float c) { return a * b + c; }, x, y, z); + // + template <typename F, typename T1, typename... Ts> + static std::unique_ptr<Array2D<T1>> ApplyElementwise2D( + F&& f, const Array2D<T1>& array1, const Array2D<Ts>&... arrays) { + AssertSameSize2D(array1, arrays...); + auto result = MakeUnique<Array2D<T1>>(array1.n1(), array1.n1()); + for (int64 i = 0; i < array1.n1(); ++i) { + for (int64 j = 0; j < array1.n2(); ++j) { + (*result)(i, j) = f(array1(i, j), arrays(i, j)...); + } + } + return result; + } + private: + template <typename T1, typename T2, typename... Ts> + static void AssertSameSize2D(const Array2D<T1>& array1, + const Array2D<T2>& array2, + const Array2D<Ts>&... arrays) { + static_assert(std::is_same<T1, T2>::value, "Args must be same type."); + CHECK_EQ(array1.n1(), array2.n1()); + CHECK_EQ(array1.n2(), array2.n2()); + AssertSameSize2D(array2, arrays...); + } + + // Recursive base case for AssertSameSize2D. + template <typename Array1> + static void AssertSameSize2D(const Array1& array1) {} + TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil); }; diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index b0aa5584028..f839ac019df 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -52,9 +52,9 @@ class ReferenceUtilTest : public ::testing::Test { TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, - *result_literal, ErrorSpec(0.0001)); + *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MatmulArray2D) { @@ -62,32 +62,32 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}}, - *result_literal, ErrorSpec(0.0001)); + *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); - auto result_literal = LiteralUtil::CreateR1<float>(*result); - LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *result_literal, + auto actual_literal = LiteralUtil::CreateR1<float>(*result); + LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); - auto result_literal = LiteralUtil::CreateR1<float>(*result); - LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *result_literal, + auto actual_literal = LiteralUtil::CreateR1<float>(*result); + LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *result_literal, + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, ErrorSpec(0.0001)); } @@ -96,9 +96,9 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { return value + row + col; }; auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, - *result_literal, ErrorSpec(0.0001)); + *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray4D) { @@ -107,11 +107,11 @@ TEST_F(ReferenceUtilTest, MapArray4D) { input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); - auto result_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); } @@ -124,11 +124,11 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width); }; auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index); - auto result_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); } @@ -302,5 +302,17 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { ErrorSpec(0.0001)); } +TEST_F(ReferenceUtilTest, ApplyElementwise2D) { + Array2D<float> a({{1, 2}, {3, 4}}); + Array2D<float> b({{10, 20}, {30, 40}}); + Array2D<float> c({{100, 200}, {300, 400}}); + + auto actual = ReferenceUtil::ApplyElementwise2D( + [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); + LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, + *actual_literal, ErrorSpec(0.0001)); +} + } // namespace } // namespace xla From 838a607635769c0b851fe7859c0a2f02495b14d0 Mon Sep 17 00:00:00 2001 From: Mark Daoust <markdaoust@google.com> Date: Mon, 1 May 2017 04:18:19 -0800 Subject: [PATCH 04/38] Clarify docs on unused GraphKeys: WEIGHTS, BIASES, ACTIVTIONS Closes #7741 Change: 154708825 --- tensorflow/contrib/layers/README.md | 7 ++----- tensorflow/python/framework/ops.py | 10 +++++++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tensorflow/contrib/layers/README.md b/tensorflow/contrib/layers/README.md index 7b374172f5c..9310b194dff 100644 --- a/tensorflow/contrib/layers/README.md +++ b/tensorflow/contrib/layers/README.md @@ -18,17 +18,14 @@ these arguments. …, weight_init=<DEFAULT>, bias_init=<DEFAULT>, - weight_collections=(tf.GraphKeys.WEIGHTS,), - bias_collections=(tf.GraphKeys.BIASES,), - output_collections=(tf.GraphKeys.ACTIVATIONS,), weight_regularizer=None, bias_regularizer=None, name=None) : Tensor` `x` is the input tensor. -Weights, biases, and activations (i.e., outputs) are, by default, added to the specified collections. Weights and biases are also added to -`tf.GraphKeys.GLOBAL_VARIABLES` and `tf.GraphKeys.TRAINABLE_VARIABLES`. +Weights and biases are added to `tf.GraphKeys.GLOBAL_VARIABLES` and +`tf.GraphKeys.TRAINABLE_VARIABLES`. ## optimizers.py diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 2a094203d7d..93a29d0d8e9 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -3981,9 +3981,13 @@ class GraphKeys(object): for more details. * `REGULARIZATION_LOSSES`: regularization losses collected during graph construction. - * `WEIGHTS`: weights inside neural network layers - * `BIASES`: biases inside neural network layers - * `ACTIVATIONS`: activations of neural network layers + + The following standard keys are _defined_, but their collections are **not** + automatically populated as many of the others are: + + * `WEIGHTS` + * `BIASES` + * `ACTIVATIONS` """ # Key to collect Variable objects that are global (shared across machines). From 9a68fb3812477751e5651477281d3eb003d5e509 Mon Sep 17 00:00:00 2001 From: Mark Daoust <markdaoust@google.com> Date: Mon, 1 May 2017 05:05:42 -0800 Subject: [PATCH 05/38] Docs fix for sdca_ops.cc Change: 154710766 --- tensorflow/core/ops/sdca_ops.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/ops/sdca_ops.cc b/tensorflow/core/ops/sdca_ops.cc index 2029ed7de22..b9a37119b74 100644 --- a/tensorflow/core/ops/sdca_ops.cc +++ b/tensorflow/core/ops/sdca_ops.cc @@ -72,17 +72,17 @@ optimizer applies each update one example at a time. Examples are sampled uniformly, and the optimizer is learning rate free and enjoys linear convergence rate. -Proximal Stochastic Dual Coordinate Ascent, Shalev-Shwartz, Shai; Zhang, Tong. -2012 arXiv1211.2717S: http://arxiv.org/pdf/1211.2717v1.pdf +[Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).<br> +Shai Shalev-Shwartz, Tong Zhang. 2012 - Loss objective = \sum f_{i}(wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w| +$$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$ -Adding vs. Averaging in Distributed Primal-Dual Optimization. -Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, Peter Richtarik, -Martin Takac http://arxiv.org/abs/1502.03508 +[Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).<br> +Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, +Peter Richtarik, Martin Takac. 2015 -Stochastic Dual Coordinate Ascent with Adaptive Probabilities -Dominik Csiba, Zheng Qu, Peter Richtarik https://arxiv.org/abs/1502.08053 +[Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).<br> +Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, squared and hinge losses. From 6dd7b9792fc1f081725ec4840fb7e5f083bff7fe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 05:18:39 -0800 Subject: [PATCH 06/38] Update ops-related pbtxt files. Change: 154711274 --- tensorflow/core/ops/ops.pbtxt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 630f255d0f9..bd963a52669 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -19650,7 +19650,7 @@ op { minimum: 1 } summary: "Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for" - description: "linear models with L1 + L2 regularization. As global optimization objective is\nstrongly-convex, the optimizer optimizes the dual objective at each step. The\noptimizer applies each update one example at a time. Examples are sampled\nuniformly, and the optimizer is learning rate free and enjoys linear convergence\nrate.\n\nProximal Stochastic Dual Coordinate Ascent, Shalev-Shwartz, Shai; Zhang, Tong.\n2012 arXiv1211.2717S: http://arxiv.org/pdf/1211.2717v1.pdf\n\n Loss objective = \\sum f_{i}(wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|\n\nAdding vs. Averaging in Distributed Primal-Dual Optimization.\nChenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, Peter Richtarik,\nMartin Takac http://arxiv.org/abs/1502.03508\n\nStochastic Dual Coordinate Ascent with Adaptive Probabilities\nDominik Csiba, Zheng Qu, Peter Richtarik https://arxiv.org/abs/1502.08053" + description: "linear models with L1 + L2 regularization. As global optimization objective is\nstrongly-convex, the optimizer optimizes the dual objective at each step. The\noptimizer applies each update one example at a time. Examples are sampled\nuniformly, and the optimizer is learning rate free and enjoys linear convergence\nrate.\n\n[Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).<br>\nShai Shalev-Shwartz, Tong Zhang. 2012\n\n$$Loss Objective = \\sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$\n\n[Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).<br>\nChenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan,\nPeter Richtarik, Martin Takac. 2015\n\n[Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).<br>\nDominik Csiba, Zheng Qu, Peter Richtarik. 2015" } op { name: "SdcaShrinkL1" From 3e06162307e34a574ef8a77b5c552a543c8797d3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 05:26:41 -0800 Subject: [PATCH 07/38] Go: Update generated wrapper functions for TensorFlow ops. Change: 154711571 --- tensorflow/go/op/wrappers.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index d4af263fc65..74f05c86382 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -7941,17 +7941,17 @@ func SdcaOptimizerAdaptative(value bool) SdcaOptimizerAttr { // uniformly, and the optimizer is learning rate free and enjoys linear convergence // rate. // -// Proximal Stochastic Dual Coordinate Ascent, Shalev-Shwartz, Shai; Zhang, Tong. -// 2012 arXiv1211.2717S: http://arxiv.org/pdf/1211.2717v1.pdf +// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).<br> +// Shai Shalev-Shwartz, Tong Zhang. 2012 // -// Loss objective = \sum f_{i}(wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w| +// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$ // -// Adding vs. Averaging in Distributed Primal-Dual Optimization. -// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, Peter Richtarik, -// Martin Takac http://arxiv.org/abs/1502.03508 +// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).<br> +// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, +// Peter Richtarik, Martin Takac. 2015 // -// Stochastic Dual Coordinate Ascent with Adaptive Probabilities -// Dominik Csiba, Zheng Qu, Peter Richtarik https://arxiv.org/abs/1502.08053 +// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).<br> +// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 // // Arguments: // sparse_example_indices: a list of vectors which contain example indices. From 2a1a0d729fb23ac8758fff64119503e8e3e7b94d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 05:55:07 -0800 Subject: [PATCH 08/38] Removing unnecessary cuda:platform.bzl Change: 154712918 --- third_party/gpus/cuda/BUILD.tpl | 2 -- third_party/gpus/cuda_configure.bzl | 13 ------------- 2 files changed, 15 deletions(-) diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index 279c4412552..f7610dd7a99 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -1,7 +1,5 @@ licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like -load("@local_config_cuda//cuda:platform.bzl", "readlink_command") - package(default_visibility = ["//visibility:public"]) config_setting( diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 87d6ec1a3b0..0fef94cf9ca 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -732,12 +732,6 @@ def _create_dummy_repository(repository_ctx): "%{cuda_include_genrules}": '', "%{cuda_headers}": '', }) - _tpl(repository_ctx, "cuda:platform.bzl", - { - "%{cuda_version}": _DEFAULT_CUDA_VERSION, - "%{cudnn_version}": _DEFAULT_CUDNN_VERSION, - "%{platform}": cpu_value, - }) # Create dummy files for the CUDA toolkit since they are still required by # tensorflow/core/platform/default/build_config:cuda. @@ -913,13 +907,6 @@ def _create_cuda_repository(repository_ctx): "%{cuda_headers}": ('":cuda-include",\n' + ' ":cudnn-include",') }) - _tpl(repository_ctx, "cuda:platform.bzl", - { - "%{cuda_version}": cuda_config.cuda_version, - "%{cudnn_version}": cuda_config.cudnn_version, - "%{platform}": cuda_config.cpu_value, - }) - # Set up crosstool/ _file(repository_ctx, "crosstool:BUILD") cc = find_cc(repository_ctx) From 474aba363a85d68fdd8f8ab8b1d9d749a91e6e7a Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne <skyewm@google.com> Date: Mon, 1 May 2017 07:56:28 -0800 Subject: [PATCH 09/38] Make FunctionLibraryDefinition::AddFunctionDef() check for conflicting op name This prevents a function from masking an existing op. Change: 154720287 --- tensorflow/core/framework/function.cc | 6 ++++++ tensorflow/core/framework/function_test.cc | 9 +++++++++ tensorflow/core/graph/graph.cc | 7 ------- tensorflow/core/graph/graph_test.cc | 15 +++++++-------- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 8a7d96c38a9..c7311559241 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -882,6 +882,12 @@ Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) { fdef.signature().name(), " already exists in function library."); } + const OpDef* op_def; + if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) { + return errors::InvalidArgument( + "Cannot add function '", fdef.signature().name(), + "' because an op with the same name already exists."); + } ptr.reset(new FunctionDefAndOpRegistration(fdef)); return Status::OK(); } diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index efc6a2edcc5..07462a575e2 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -944,6 +944,15 @@ TEST(FunctionLibraryDefinitionTest, AddFunctionDef) { ASSERT_NE(second, nullptr); EXPECT_EQ(second->DebugString(), test::function::WXPlusB().signature().DebugString()); + + // Can't add function with same name as existing op + FunctionDef fdef = test::function::XTimesTwo(); + fdef.mutable_signature()->set_name("Add"); + Status s = lib_def.AddFunctionDef(fdef); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + "Cannot add function 'Add' because an op with the same name " + "already exists."); } TEST(FunctionLibraryDefinitionTest, AddGradientDef) { diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index fae9f26f762..e1657cb8622 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -380,13 +380,6 @@ Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { // Ignore duplicate FunctionDefs continue; } - // TODO(skyewm): fix test breakages and reenable this check - // const OpDef* op_def; - // if (ops_.LookUpOpDef(fdef.signature().name(), &op_def).ok()) { - // return errors::InvalidArgument( - // "Cannot add function '", fdef.signature().name(), - // "' because an op with the same name already exists."); - // } TF_RETURN_IF_ERROR(ops_.AddFunctionDef(fdef)); } for (const GradientDef& grad : fdef_lib.gradient()) { diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index 739ad90efd2..502b7b26da8 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -412,15 +412,14 @@ TEST_F(GraphTest, AddFunctionLibrary) { "Cannot add function 'XTimesTwo' because a different function with " "the same name already exists."); - // TODO(skyewm): reenable along with duplicate op check // Function with same name as an existing op triggers an error - // error_proto = proto; - // error_proto.mutable_function(0)->mutable_signature()->set_name("Add"); - // s = graph_.AddFunctionLibrary(error_proto); - // EXPECT_FALSE(s.ok()); - // EXPECT_EQ(s.error_message(), - // "Cannot add function 'Add' because an op with the same name " - // "already exists."); + error_proto = proto; + error_proto.mutable_function(0)->mutable_signature()->set_name("Add"); + s = graph_.AddFunctionLibrary(error_proto); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + "Cannot add function 'Add' because an op with the same name " + "already exists."); // Adding a gradient function to an existing function is ok GradientDef* grad = proto.add_gradient(); From b1761a2b4309b43a1245beb947677da1e80681e5 Mon Sep 17 00:00:00 2001 From: Shanqing Cai <cais@google.com> Date: Mon, 1 May 2017 07:58:55 -0800 Subject: [PATCH 10/38] tfdbg: improve the way in which the no-debugger-linked error is raised Change: 154720434 --- .../debugger_state_interface.cc | 33 ++++++++++++++----- .../common_runtime/debugger_state_interface.h | 14 ++++---- .../core/common_runtime/direct_session.cc | 24 ++++---------- .../core/distributed_runtime/graph_mgr.cc | 11 ++----- .../distributed_runtime/master_session.cc | 12 ++----- 5 files changed, 44 insertions(+), 50 deletions(-) diff --git a/tensorflow/core/common_runtime/debugger_state_interface.cc b/tensorflow/core/common_runtime/debugger_state_interface.cc index 73157ca05de..c1a92f9a221 100644 --- a/tensorflow/core/common_runtime/debugger_state_interface.cc +++ b/tensorflow/core/common_runtime/debugger_state_interface.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/debugger_state_interface.h" +#include "tensorflow/core/lib/core/errors.h" + namespace tensorflow { // static @@ -58,11 +60,17 @@ void DebuggerStateRegistry::RegisterFactory( } // static -std::unique_ptr<DebuggerStateInterface> DebuggerStateRegistry::CreateState( - const DebugOptions& debug_options) { - return (factory_ == nullptr || *factory_ == nullptr) - ? nullptr - : (*factory_)(debug_options); +Status DebuggerStateRegistry::CreateState( + const DebugOptions& debug_options, + std::unique_ptr<DebuggerStateInterface>* state) { + if (factory_ == nullptr || *factory_ == nullptr) { + return errors::Internal( + "Creation of debugger state failed. " + "It appears that TFDBG is not linked in this TensorFlow build."); + } else { + *state = (*factory_)(debug_options); + return Status::OK(); + } } // static @@ -73,10 +81,17 @@ void DebugGraphDecoratorRegistry::RegisterFactory( } // static -std::unique_ptr<DebugGraphDecoratorInterface> -DebugGraphDecoratorRegistry::CreateDecorator(const DebugOptions& options) { - return (factory_ == nullptr || *factory_ == nullptr) ? nullptr - : (*factory_)(options); +Status DebugGraphDecoratorRegistry::CreateDecorator( + const DebugOptions& options, + std::unique_ptr<DebugGraphDecoratorInterface>* decorator) { + if (factory_ == nullptr || *factory_ == nullptr) { + return errors::Internal( + "Creation of graph decorator failed. " + "It appears that TFDBG is not linked in this TensorFlow build."); + } else { + *decorator = (*factory_)(options); + return Status::OK(); + } } } // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/debugger_state_interface.h b/tensorflow/core/common_runtime/debugger_state_interface.h index d182ce7092d..6f197f372fa 100644 --- a/tensorflow/core/common_runtime/debugger_state_interface.h +++ b/tensorflow/core/common_runtime/debugger_state_interface.h @@ -83,11 +83,12 @@ class DebuggerStateRegistry { // implementation based on DebugOptions. static void RegisterFactory(const DebuggerStateFactory& factory); - // If RegisterFactory() has been called, creates and returns a concrete + // If RegisterFactory() has been called, creates and supplies a concrete // DebuggerStateInterface implementation using the registered factory, - // owned by the caller. Otherwise returns nullptr. - static std::unique_ptr<DebuggerStateInterface> CreateState( - const DebugOptions& debug_options); + // owned by the caller and return an OK Status. Otherwise returns an error + // Status. + static Status CreateState(const DebugOptions& debug_options, + std::unique_ptr<DebuggerStateInterface>* state); private: static DebuggerStateFactory* factory_; @@ -103,8 +104,9 @@ class DebugGraphDecoratorRegistry { public: static void RegisterFactory(const DebugGraphDecoratorFactory& factory); - static std::unique_ptr<DebugGraphDecoratorInterface> CreateDecorator( - const DebugOptions& options); + static Status CreateDecorator( + const DebugOptions& options, + std::unique_ptr<DebugGraphDecoratorInterface>* decorator); private: static DebugGraphDecoratorFactory* factory_; diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 7c017f9584a..f1c17d77886 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -376,31 +376,19 @@ Status DirectSession::CreateDebuggerState( const std::vector<string>& output_names, const std::vector<string>& target_names, std::unique_ptr<DebuggerStateInterface>* debugger_state) { - std::unique_ptr<DebuggerStateInterface> state = - DebuggerStateRegistry::CreateState(debug_options); - if (!state) { - return errors::Internal( - "Debugger options are set, but creation of debugger state failed. " - "It appears that debugger is not linked in this TensorFlow build."); - } - - TF_RETURN_IF_ERROR(state->PublishDebugMetadata( + TF_RETURN_IF_ERROR( + DebuggerStateRegistry::CreateState(debug_options, debugger_state)); + TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata( debug_options.global_step(), session_run_count, executor_step_count, input_names, output_names, target_names)); - - *debugger_state = std::move(state); return Status::OK(); } Status DirectSession::DecorateAndPublishGraphForDebug( const DebugOptions& debug_options, Graph* graph, Device* device) { - std::unique_ptr<DebugGraphDecoratorInterface> decorator = - DebugGraphDecoratorRegistry::CreateDecorator(debug_options); - if (!decorator) { - return errors::Internal( - "Debugger options are set, but creation of debug graph publisher ", - "failed."); - } + std::unique_ptr<DebugGraphDecoratorInterface> decorator; + TF_RETURN_IF_ERROR( + DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator)); TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device)); TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph)); diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 6051f11fadb..ce7ce372e85 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -97,16 +97,11 @@ static Status ValidateGraphDefForDevices(const GraphDef& gdef) { Status GraphMgr::DecorateAndPublishGraphForDebug( const DebugOptions& debug_options, Graph* graph, Device* device) { - std::unique_ptr<DebugGraphDecoratorInterface> decorator = - DebugGraphDecoratorRegistry::CreateDecorator(debug_options); - if (!decorator) { - return errors::Internal( - "Debugger options are set, but creation of debug graph publisher ", - "failed."); - } + std::unique_ptr<DebugGraphDecoratorInterface> decorator; + TF_RETURN_IF_ERROR( + DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator)); TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device)); TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph)); - return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 681933adad3..f7b422b70e3 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -1337,13 +1337,8 @@ Status MasterSession::CreateDebuggerState( const DebugOptions& debug_options, const RunStepRequestWrapper& req, int64 rcg_execution_count, std::unique_ptr<DebuggerStateInterface>* debugger_state) { - std::unique_ptr<DebuggerStateInterface> state = - DebuggerStateRegistry::CreateState(debug_options); - if (!state) { - return errors::Internal( - "Debugger options are set, but creation of debugger state failed. " - "It appears that debugger is not linked in this TensorFlow build."); - } + TF_RETURN_IF_ERROR( + DebuggerStateRegistry::CreateState(debug_options, debugger_state)); std::vector<string> input_names; for (size_t i = 0; i < req.num_feeds(); ++i) { @@ -1362,11 +1357,10 @@ Status MasterSession::CreateDebuggerState( // While this counter value is straightforward to define and obtain for // DirectSessions, it is less so for non-direct Sessions. Devise a better // way to get its value when the need arises. - TF_RETURN_IF_ERROR(state->PublishDebugMetadata( + TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata( debug_options.global_step(), -1, rcg_execution_count, input_names, output_names, target_names)); - *debugger_state = std::move(state); return Status::OK(); } From fcfb5482da6da3221f8ea843829c24a289c4a4c8 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne <skyewm@google.com> Date: Mon, 1 May 2017 08:11:42 -0800 Subject: [PATCH 11/38] Update "Using TensorFlow via Docker" README to mention memory limits. Fixes #7861 Change: 154721550 --- tensorflow/tools/ci_build/README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/tools/ci_build/README.md b/tensorflow/tools/ci_build/README.md index 1fa618e698f..ad83669950f 100644 --- a/tensorflow/tools/ci_build/README.md +++ b/tensorflow/tools/ci_build/README.md @@ -34,7 +34,11 @@ run continuous integration [ci.tensorflow.org](https://ci.tensorflow.org). ```bash tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/... ``` - + If you are using the Docker image on Windows or OS X, the Docker VM's default + memory limit may be too low to build TensorFlow. This can result in + strange-looking errors, e.g. the compilation may fail with `gcc: internal + compiler error: Killed (program cc1plus)`. Try increasing the memory limit in + the Docker preferences. ## Jobs From 00544564b9438a220c79fb554d39cb0e212bfd27 Mon Sep 17 00:00:00 2001 From: Martin Wicke <wicke@google.com> Date: Mon, 1 May 2017 09:02:52 -0800 Subject: [PATCH 12/38] Remove outdated README. Change: 154726428 --- .../contrib/learn/python/learn/README.md | 247 ------------------ 1 file changed, 247 deletions(-) delete mode 100644 tensorflow/contrib/learn/python/learn/README.md diff --git a/tensorflow/contrib/learn/python/learn/README.md b/tensorflow/contrib/learn/python/learn/README.md deleted file mode 100644 index 6a7b0ea6141..00000000000 --- a/tensorflow/contrib/learn/python/learn/README.md +++ /dev/null @@ -1,247 +0,0 @@ -# TF Learn - -TF Learn is a simplified interface for TensorFlow, to get people started on predictive analytics and data mining. The library covers a variety of needs: from linear models to *Deep Learning* applications like text and image understanding. - -### Why *TensorFlow*? - -* TensorFlow provides a good backbone for building different shapes of machine learning applications. -* It will continue to evolve both in the distributed direction and as general pipelinining machinery. - -### Why *TensorFlow Learn*? - -- To smooth the transition from the [scikit-learn](http://scikit-learn.org/stable/) world of one-liner machine learning into the more open world of building different shapes of ML models. You can start by using [fit](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Estimator#fit)/[predict](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Estimator#predict) and slide into TensorFlow APIs as you are getting comfortable. -- To provide a set of reference models that will be easy to integrate with existing code. - -## Installation - -[Install TensorFlow](https://www.tensorflow.org/install/), and then simply import `learn` via `from tensorflow.contrib.learn` or use `tf.contrib.learn`. - -Optionally you can install [scikit-learn](http://scikit-learn.org/stable/) and [pandas](http://pandas.pydata.org/) for additional functionality. - -### Tutorials - -- [TF Learn Quickstart](https://www.tensorflow.org/get_started/tflearn). Build, - train, and evaluate a neural network with just a few lines of code. -- [Input Functions](https://www.tensorflow.org/get_started/input_fn). Learn how - to create input functions to feed data into your models. -- [Linear Model](https://www.tensorflow.org/tutorials/wide). Learn the basics - of building linear models. -- [Wide and Deep Learning](https://www.tensorflow.org/tutorials/wide_and_deep). - Jointly train a linear model and a deep neural network. -- [Logging and Monitoring](https://www.tensorflow.org/get_started/monitors). - Use the Monitor API to audit training of a neural network. -- [Custom Estimators](https://www.tensorflow.org/extend/estimators). Learn - how to create a custom estimator. -- More coming soon. - -### Community - -- Twitter [#tensorflow](https://twitter.com/search?q=tensorflow&src=typd). -- StackOverflow with [tensorflow tag](http://stackoverflow.com/questions/tagged/tensorflow) for questions and struggles. -- GitHub [issues](https://github.com/tensorflow/tensorflow/issues) for technical discussions and feature requests. - -### Existing Estimator Implementations - -- [`LinearClassifier`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/linear.py) - ([docs](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/LinearClassifier)) -- [`LinearRegressor`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/linear.py) - ([docs](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/LinearRegressor)) -- [`DNNClassifier`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/dnn.py) - ([docs](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/DNNClassifier)) -- [`DNNRegressor`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/dnn.py) - ([docs](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/DNNRegressor)) -- [`DNNLinearCombinedClassifier`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py) - ([docs](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/DNNLinearCombinedClassifier)) -- [`DNNLinearCombinedRegressor`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py) - ([docs](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/DNNLinearCombinedRegressor)) -- [`SVM`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/svm.py) - ([docs](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/g3doc/svm.md)) -- [`GMM`](https://www.tensorflow.org/code/tensorflow/contrib/factorization/python/ops/gmm.py) - ([docs](https://www.tensorflow.org/code/tensorflow/contrib/factorization/g3doc/gmm.md)) -- [`KMeansClustering`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/kmeans.py) - ([docs](https://www.tensorflow.org/code/tensorflow/contrib/factorization/g3doc/kmeans.md)) - -### Usage Examples - -Below are a few simple examples of the API. For more examples, please see [examples](https://www.tensorflow.org/code/tensorflow/examples/learn). - -General tips: - -- It's useful to rescale a dataset to 0 mean and unit standard deviation before passing it to an [`Estimator`](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Estimator). [Stochastic Gradient Descent](https://en.wikipedia.org/wiki/Stochastic_gradient_descent) doesn't always do the right thing when variable are at very different scales. - -- Categorical variables should be managed before passing input to the estimator. - -## Linear Classifier - -Simple linear classification: - -```python -import tensorflow.contrib.learn.python.learn as learn -from sklearn import datasets, metrics - -iris = datasets.load_iris() -feature_columns = learn.infer_real_valued_columns_from_input(iris.data) -classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns) -classifier.fit(iris.data, iris.target, steps=200, batch_size=32) -iris_predictions = list(classifier.predict(iris.data, as_iterable=True)) -score = metrics.accuracy_score(iris.target, iris_predictions) -print("Accuracy: %f" % score) -``` - -## Linear Regressor - -Simple linear regression: - -```python -import tensorflow.contrib.learn.python.learn as learn -from sklearn import datasets, metrics, preprocessing - -boston = datasets.load_boston() -x = preprocessing.StandardScaler().fit_transform(boston.data) -feature_columns = learn.infer_real_valued_columns_from_input(x) -regressor = learn.LinearRegressor(feature_columns=feature_columns) -regressor.fit(x, boston.target, steps=200, batch_size=32) -boston_predictions = list(regressor.predict(x, as_iterable=True)) -score = metrics.mean_squared_error(boston_predictions, boston.target) -print ("MSE: %f" % score) -``` - -## Deep Neural Network - -Example of 3 layer network with 10, 20 and 10 hidden units respectively: - -```python -import tensorflow.contrib.learn.python.learn as learn -from sklearn import datasets, metrics - -iris = datasets.load_iris() -feature_columns = learn.infer_real_valued_columns_from_input(iris.data) -classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3, feature_columns=feature_columns) -classifier.fit(iris.data, iris.target, steps=200, batch_size=32) -iris_predictions = list(classifier.predict(iris.data, as_iterable=True)) -score = metrics.accuracy_score(iris.target, iris_predictions) -print("Accuracy: %f" % score) -``` - -## Custom model - -Example of how to pass a custom model to the Estimator: - -```python -from sklearn import datasets -from sklearn import metrics -import tensorflow as tf -import tensorflow.contrib.layers.python.layers as layers -import tensorflow.contrib.learn.python.learn as learn - -iris = datasets.load_iris() - -def my_model(features, labels): - """DNN with three hidden layers.""" - # Convert the labels to a one-hot tensor of shape (length of features, 3) and - # with a on-value of 1 for each one-hot vector of length 3. - labels = tf.one_hot(labels, 3, 1, 0) - - # Create three fully connected layers respectively of size 10, 20, and 10. - features = layers.stack(features, layers.fully_connected, [10, 20, 10]) - - # Create two tensors respectively for prediction and loss. - prediction, loss = ( - tf.contrib.learn.models.logistic_regression(features, labels) - ) - - # Create a tensor for training op. - train_op = tf.contrib.layers.optimize_loss( - loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad', - learning_rate=0.1) - - return {'class': tf.argmax(prediction, 1), 'prob': prediction}, loss, train_op - -classifier = learn.Estimator(model_fn=my_model) -classifier.fit(iris.data, iris.target, steps=1000) - -y_predicted = [ - p['class'] for p in classifier.predict(iris.data, as_iterable=True)] -score = metrics.accuracy_score(iris.target, y_predicted) -print('Accuracy: {0:f}'.format(score)) -``` - -## Saving / Restoring models - -Each estimator supports a `model_dir` argument, which takes a folder path where all model information will be saved: - -```python -classifier = learn.DNNClassifier(..., model_dir="/tmp/my_model") -``` - -If you run multiple `fit` operations on the same `Estimator`, training will resume where the last operation left off, e.g.: - -<pre><strong>classifier = learn.DNNClassifier(..., model_dir="/tmp/my_model") -classifier.fit(..., steps=300)</strong> -INFO:tensorflow:Create CheckpointSaverHook -INFO:tensorflow:loss = 2.40115, step = 1 -INFO:tensorflow:Saving checkpoints for 1 into /tmp/leftoff/model.ckpt. -INFO:tensorflow:loss = 0.338706, step = 101 -INFO:tensorflow:loss = 0.159414, step = 201 -INFO:tensorflow:Saving checkpoints for 300 into /tmp/leftoff/model.ckpt. -INFO:tensorflow:Loss for final step: 0.0953846. - -<strong>classifier.fit(..., steps=300)</strong> -INFO:tensorflow:Create CheckpointSaverHook -INFO:tensorflow:loss = 0.113173, step = 301 -INFO:tensorflow:Saving checkpoints for 301 into /tmp/leftoff/model.ckpt. -INFO:tensorflow:loss = 0.175782, step = 401 -INFO:tensorflow:loss = 0.119735, step = 501 -INFO:tensorflow:Saving checkpoints for 600 into /tmp/leftoff/model.ckpt. -INFO:tensorflow:Loss for final step: 0.0518137.</pre> - -To restore checkpoints to a new `Estimator`, just pass it the same `model_dir` argument, e.g.: - -<pre><strong>classifier = learn.DNNClassifier(..., model_dir="/tmp/my_model") -classifier.fit(..., steps=300)</strong> -INFO:tensorflow:Create CheckpointSaverHook -INFO:tensorflow:loss = 1.16335, step = 1 -INFO:tensorflow:Saving checkpoints for 1 into /tmp/leftoff/model.ckpt. -INFO:tensorflow:loss = 0.176995, step = 101 -INFO:tensorflow:loss = 0.184573, step = 201 -INFO:tensorflow:Saving checkpoints for 300 into /tmp/leftoff/model.ckpt. -INFO:tensorflow:Loss for final step: 0.0512496. - -<strong>classifier2 = learn.DNNClassifier(..., model_dir="/tmp/my_model") -classifier2.fit(..., steps=300)</strong> -INFO:tensorflow:Create CheckpointSaverHook -INFO:tensorflow:loss = 0.0543797, step = 301 -INFO:tensorflow:Saving checkpoints for 301 into /tmp/leftoff/model.ckpt. -INFO:tensorflow:loss = 0.101036, step = 401 -INFO:tensorflow:loss = 0.137956, step = 501 -INFO:tensorflow:Saving checkpoints for 600 into /tmp/leftoff/model.ckpt. -INFO:tensorflow:Loss for final step: 0.0162506.</pre> - -## Summaries - -If you supply a `model_dir` argument to your `Estimator`s, TensorFlow will write summaries for ``loss`` and histograms for variables in this directory. (You can also add custom summaries in your custom model function by calling [Summary](https://www.tensorflow.org/api_guides/python/summary) operations.) - -To view the summaries in TensorBoard, run the following command, where `logdir` is the `model_dir` for your `Estimator`: - -```shell -tensorboard --logdir=/tmp/tf_examples/my_model_1 -``` - -and then load the reported URL. - -**Graph visualization** - - - -**Loss visualization** - - - -## More examples - -See the [examples folder](https://www.tensorflow.org/code/tensorflow/examples/learn) for: - -- An easy way to handle [categorical variables](https://www.tensorflow.org/code/tensorflow/examples/learn/text_classification.py) (words are just an example of a categorical variable) -- Text Classification: see examples for [RNN](https://www.tensorflow.org/code/tensorflow/examples/learn/text_classification_character_rnn.py) and [CNN](https://www.tensorflow.org/code/tensorflow/examples/learn/text_classification_character_cnn.py) on characters -- [Digit recognition using a CNN](https://www.tensorflow.org/code/tensorflow/examples/learn/mnist.py) -- And much more! From a642a6ecf6685a73e7923c94ff69db47d67aa079 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 09:06:26 -0800 Subject: [PATCH 13/38] Enabling python configuration to use a remotely generated configuration. This will prepare TF for use of Bazel's new remote execution api. Change: 154726835 --- third_party/py/python_configure.bzl | 17 +++++++++++++++-- third_party/py/remote.BUILD.tpl | 13 +++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 third_party/py/remote.BUILD.tpl diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl index 460a9b5d057..928bd333c61 100644 --- a/third_party/py/python_configure.bzl +++ b/third_party/py/python_configure.bzl @@ -112,6 +112,7 @@ def _genrule(src_dir, genrule_name, command, outs): ' cmd = """\n' + command + ' """,\n' + + ' visibility = ["//visibility:private"],' + ')\n' ) @@ -157,7 +158,7 @@ def _get_numpy_include(repository_ctx, python_bin): return result.stdout.splitlines()[0] -def _create_python_repository(repository_ctx): +def _create_local_python_repository(repository_ctx): """Creates the repository containing files set up to build with Python.""" python_include = None numpy_include = None @@ -208,9 +209,20 @@ def _create_python_repository(repository_ctx): }) +def _create_remote_python_repository(repository_ctx): + """Creates pointers to a remotely configured repo set up to build with Python. + """ + _tpl(repository_ctx, "remote.BUILD", { + "%{REMOTE_PYTHON_REPO}": repository_ctx.attr.remote_config_repo, + }, "BUILD") + + def _python_autoconf_impl(repository_ctx): """Implementation of the python_autoconf repository rule.""" - _create_python_repository(repository_ctx) + if repository_ctx.attr.remote_config_repo != "": + _create_remote_python_repository(repository_ctx) + else: + _create_local_python_repository(repository_ctx) python_configure = repository_rule( @@ -219,6 +231,7 @@ python_configure = repository_rule( "local_checks": attr.bool(mandatory = False, default = True), "python_include": attr.string(mandatory = False), "numpy_include": attr.string(mandatory = False), + "remote_config_repo": attr.string(mandatory = False, default =""), }, environ = [ _PYTHON_BIN_PATH, diff --git a/third_party/py/remote.BUILD.tpl b/third_party/py/remote.BUILD.tpl new file mode 100644 index 00000000000..332508ec84c --- /dev/null +++ b/third_party/py/remote.BUILD.tpl @@ -0,0 +1,13 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +alias( + name = "python_headers", + actual = "@%{REMOTE_PYTHON_REPO}//:python_headers", +) + +alias( + name = "numpy_headers", + actual = "@%{REMOTE_PYTHON_REPO}//:numpy_headers", +) From 65283e269fac1306089303143daf550b7b1a6658 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 09:50:31 -0800 Subject: [PATCH 14/38] Move bucketize op to math_ops. This exposes the method as math_op._bucketize, but does not expose it in tf public interface. The method is still exposed as tf.contrib.layers.bucketize. Change: 154731971 --- tensorflow/contrib/BUILD | 2 - .../contrib/cmake/tf_core_kernels.cmake | 2 - tensorflow/contrib/cmake/tf_core_ops.cmake | 1 - tensorflow/contrib/cmake/tf_python.cmake | 2 - tensorflow/contrib/layers/BUILD | 49 ------------------- tensorflow/contrib/layers/kernels/BUILD | 11 ----- .../contrib/layers/ops/bucketization_op.cc | 46 ----------------- .../layers/python/ops/bucketization_op.py | 17 ++----- tensorflow/core/kernels/BUILD | 7 +++ .../kernels/bucketize_op.cc} | 2 + tensorflow/core/ops/math_ops.cc | 31 ++++++++++++ tensorflow/python/kernel_tests/BUILD | 12 +++++ .../kernel_tests/bucketize_op_test.py} | 36 +++++++++++--- tensorflow/python/ops/hidden_ops.txt | 1 + tensorflow/python/ops/math_ops.py | 6 +++ 15 files changed, 91 insertions(+), 134 deletions(-) delete mode 100644 tensorflow/contrib/layers/ops/bucketization_op.cc rename tensorflow/{contrib/layers/kernels/bucketization_kernel.cc => core/kernels/bucketize_op.cc} (98%) rename tensorflow/{contrib/layers/python/kernel_tests/bucketization_op_test.py => python/kernel_tests/bucketize_op_test.py} (55%) diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index c6585082884..616ceb5ca17 100755 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -78,7 +78,6 @@ cc_library( "//tensorflow/contrib/batching:batch_ops_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", - "//tensorflow/contrib/layers:bucketization_op_kernel", "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/nccl:nccl_kernels", "//tensorflow/contrib/tensor_forest:tensor_forest_kernels", @@ -93,7 +92,6 @@ cc_library( "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", - "//tensorflow/contrib/layers:bucketization_op_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 0c420a02534..a71d9c5869b 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -47,9 +47,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc" #"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/decode_audio_op.cc" #"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/bucketization_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" - "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 2beb264a54e..2a19433a7b2 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -70,7 +70,6 @@ GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(input_pipeline "${tensorflow_source_dir}/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(layers_bucketization "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc") GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 39f5edf25d3..53ebfbb57de 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -623,8 +623,6 @@ GENERATE_PYTHON_OP_LIB("contrib_input_pipeline_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/input_pipeline/ops/gen_input_pipeline_ops.py) GENERATE_PYTHON_OP_LIB("contrib_image_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/image/ops/gen_image_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_layers_bucketization_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_bucketization_op.py) GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py) GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index fc33e4b49e0..aba8eabe10c 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -18,35 +18,6 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -tf_custom_op_library( - # TODO(sibyl-Mooth6ku,ptucker): Understand why 'python/ops/_' is needed and fix it. - name = "python/ops/_bucketization_op.so", - srcs = [ - "ops/bucketization_op.cc", - ], - deps = [ - "//tensorflow/contrib/layers/kernels:bucketization_kernel", - ], -) - -tf_gen_op_libs( - op_lib_names = ["bucketization_op"], -) - -tf_gen_op_wrapper_py( - name = "bucketization_op", - deps = [":bucketization_op_op_lib"], -) - -tf_kernel_library( - name = "bucketization_op_kernel", - deps = [ - "//tensorflow/contrib/layers/kernels:bucketization_kernel", - "//tensorflow/core:framework", - ], - alwayslink = 1, -) - tf_custom_op_library( # TODO(sibyl-Mooth6ku,ptucker): Understand why 'python/ops/_' is needed and fix it. name = "python/ops/_sparse_feature_cross_op.so", @@ -97,18 +68,14 @@ tf_custom_op_py_library( "python/ops/sparse_ops.py", ], dso = [ - ":python/ops/_bucketization_op.so", ":python/ops/_sparse_feature_cross_op.so", ], kernels = [ - ":bucketization_op_kernel", ":sparse_feature_cross_op_kernel", - ":bucketization_op_op_lib", ":sparse_feature_cross_op_op_lib", ], srcs_version = "PY2AND3", deps = [ - ":bucketization_op", ":sparse_feature_cross_op", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/lookup:lookup_py", @@ -315,22 +282,6 @@ py_test( ], ) -py_test( - name = "bucketization_op_test", - size = "small", - srcs = ["python/kernel_tests/bucketization_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":layers_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//third_party/py/numpy", - ], -) - py_test( name = "sparse_feature_cross_op_test", size = "medium", diff --git a/tensorflow/contrib/layers/kernels/BUILD b/tensorflow/contrib/layers/kernels/BUILD index 7a2d6d8c4f7..15b984f9389 100644 --- a/tensorflow/contrib/layers/kernels/BUILD +++ b/tensorflow/contrib/layers/kernels/BUILD @@ -7,17 +7,6 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) -cc_library( - name = "bucketization_kernel", - srcs = ["bucketization_kernel.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf//:protobuf_headers", - ], - alwayslink = 1, -) - cc_library( name = "sparse_feature_cross_kernel", srcs = ["sparse_feature_cross_kernel.cc"], diff --git a/tensorflow/contrib/layers/ops/bucketization_op.cc b/tensorflow/contrib/layers/ops/bucketization_op.cc deleted file mode 100644 index d90d47a1eb4..00000000000 --- a/tensorflow/contrib/layers/ops/bucketization_op.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("Bucketize") - .Input("input: T") - .Output("output: int32") - .Attr("T: {int32, int64, float, double}") - .Attr("boundaries: list(float)") - .SetShapeFn(shape_inference::UnchangedShape) - .Doc(R"doc( -Bucketizes 'input' based on 'boundaries'. - -For example, if the inputs are - boundaries = [0, 10, 100] - input = [[-5, 10000] - [150, 10] - [5, 100]] - -then the output will be - output = [[0, 3] - [3, 2] - [1, 3]] - -input: Any shape of Tensor contains with int or float type. -boundaries: A sorted list of floats gives the boundary of the buckets. -output: Same shape with 'input', each value of input replaced with bucket index. - -)doc"); -} // namespace tensorflow diff --git a/tensorflow/contrib/layers/python/ops/bucketization_op.py b/tensorflow/contrib/layers/python/ops/bucketization_op.py index b941a9b8222..f498352855f 100644 --- a/tensorflow/contrib/layers/python/ops/bucketization_op.py +++ b/tensorflow/contrib/layers/python/ops/bucketization_op.py @@ -17,13 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.layers.ops import gen_bucketization_op -from tensorflow.contrib.util import loader -from tensorflow.python.framework import ops -from tensorflow.python.platform import resource_loader - -_bucketization_op = loader.load_op_library( - resource_loader.get_path_to_datafile("_bucketization_op.so")) +from tensorflow.python.ops import math_ops def bucketize(input_tensor, boundaries, name=None): @@ -43,10 +37,5 @@ def bucketize(input_tensor, boundaries, name=None): Raises: TypeError: If boundaries is not a list. """ - if not isinstance(boundaries, list): - raise TypeError("boundaries must be a list") - - return gen_bucketization_op.bucketize(input_tensor, boundaries, name=name) - - -ops.NotDifferentiable("Bucketize") + return math_ops._bucketize( # pylint: disable=protected-access + input_tensor, boundaries=boundaries, name=name) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 8140f5d67ea..231e06d5f4d 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2249,6 +2249,7 @@ cc_library( ":batch_matmul_op", ":betainc_op", ":bincount_op", + ":bucketize_op", ":cast_op", ":check_numerics_op", ":cross_op", @@ -2286,6 +2287,12 @@ tf_kernel_library( deps = MATH_DEPS, ) +tf_kernel_library( + name = "bucketize_op", + prefix = "bucketize_op", + deps = MATH_DEPS, +) + tf_kernel_library( name = "cast_op", prefix = "cast_op", diff --git a/tensorflow/contrib/layers/kernels/bucketization_kernel.cc b/tensorflow/core/kernels/bucketize_op.cc similarity index 98% rename from tensorflow/contrib/layers/kernels/bucketization_kernel.cc rename to tensorflow/core/kernels/bucketize_op.cc index 5cfa39de764..93c2d01221f 100644 --- a/tensorflow/contrib/layers/kernels/bucketization_kernel.cc +++ b/tensorflow/core/kernels/bucketize_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// See docs in ../ops/math_ops.cc. + #include <algorithm> #include <vector> diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 0942e8ae755..aafbd1b5cce 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -2356,4 +2356,35 @@ output_max: the computed max output. )doc"); +// -------------------------------------------------------------------------- + +REGISTER_OP("Bucketize") + .Input("input: T") + .Output("output: int32") + .Attr("T: {int32, int64, float, double}") + .Attr("boundaries: list(float)") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Bucketizes 'input' based on 'boundaries'. + +For example, if the inputs are + boundaries = [0, 10, 100] + input = [[-5, 10000] + [150, 10] + [5, 100]] + +then the output will be + output = [[0, 3] + [3, 2] + [1, 3]] + +input: Any shape of Tensor contains with int or float type. +boundaries: A sorted list of floats gives the boundary of the buckets. +output: Same shape with 'input', each value of input replaced with bucket index. + +@compatibility(numpy) +Equivalent to np.digitize. +@end_compatibility +)doc"); + } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 06a0aa468a1..a81e6f54827 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2569,6 +2569,18 @@ tf_py_test( ], ) +tf_py_test( + name = "bucketize_op_test", + size = "small", + srcs = ["bucketize_op_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/layers/python/kernel_tests/bucketization_op_test.py b/tensorflow/python/kernel_tests/bucketize_op_test.py similarity index 55% rename from tensorflow/contrib/layers/python/kernel_tests/bucketization_op_test.py rename to tensorflow/python/kernel_tests/bucketize_op_test.py index abc6cc5674c..ed53cc62940 100644 --- a/tensorflow/contrib/layers/python/kernel_tests/bucketization_op_test.py +++ b/tensorflow/python/kernel_tests/bucketize_op_test.py @@ -12,35 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for bucketization_op.""" +"""Tests for bucketize_op.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.layers.python.ops import bucketization_op from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test class BucketizationOpTest(test.TestCase): - def test_normal_usecase(self): - op = bucketization_op.bucketize( + def testInt(self): + op = math_ops._bucketize( constant_op.constant([-5, 0, 2, 3, 5, 8, 10, 11, 12]), boundaries=[0, 3, 8, 11]) expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] with self.test_session() as sess: self.assertAllEqual(expected_out, sess.run(op)) - def test_invalid_boundaries_order(self): - op = bucketization_op.bucketize( + def testFloat(self): + op = math_ops._bucketize( + constant_op.constant([-5., 0., 2., 3., 5., 8., 10., 11., 12.]), + boundaries=[0., 3., 8., 11.]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + with self.test_session() as sess: + self.assertAllEqual(expected_out, sess.run(op)) + + def test2DInput(self): + op = math_ops._bucketize( + constant_op.constant([[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]), + boundaries=[0, 3, 8, 11]) + expected_out = [[0, 1, 1, 2, 2], [3, 3, 4, 4, 1]] + with self.test_session() as sess: + self.assertAllEqual(expected_out, sess.run(op)) + + def testInvalidBoundariesOrder(self): + op = math_ops._bucketize( constant_op.constant([-5, 0]), boundaries=[0, 8, 3, 11]) with self.test_session() as sess: - with self.assertRaises(errors_impl.InvalidArgumentError): + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, "Expected sorted boundaries"): sess.run(op) + def testBoundariesNotList(self): + with self.assertRaisesRegexp( + TypeError, "Expected list for attr boundaries"): + math_ops._bucketize(constant_op.constant([-5, 0]), boundaries=0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index f302477068b..9022e1453df 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -220,6 +220,7 @@ BatchFFT3D BatchIFFT BatchIFFT2D BatchIFFT3D +Bucketize Complex ComplexAbs Conj diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 572e7b71db5..92e8723ced8 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -241,6 +241,12 @@ def abs(x, name=None): # pylint: enable=g-docstring-has-escape +# pylint: disable=redefined-builtin +def _bucketize(input, boundaries, name=None): + return gen_math_ops._bucketize(input=input, boundaries=boundaries, name=name) +# pylint: enable=redefined-builtin + + class DivideDelegateWithName(object): """Use Python2/Python3 division delegation to implement divide for tensors.""" From 148df55690df08ad0ca9510bbce601f467c10b81 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 10:00:08 -0800 Subject: [PATCH 15/38] Go: Update generated wrapper functions for TensorFlow ops. Change: 154733347 --- tensorflow/go/op/wrappers.go | 38 ++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 74f05c86382..001e5f9feb2 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -18744,6 +18744,44 @@ func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output, return op.Output(0), op.Output(1) } +// Bucketizes 'input' based on 'boundaries'. +// +// For example, if the inputs are +// boundaries = [0, 10, 100] +// input = [[-5, 10000] +// [150, 10] +// [5, 100]] +// +// then the output will be +// output = [[0, 3] +// [3, 2] +// [1, 3]] +// +// Arguments: +// input: Any shape of Tensor contains with int or float type. +// boundaries: A sorted list of floats gives the boundary of the buckets. +// +// Returns Same shape with 'input', each value of input replaced with bucket index. +// +// @compatibility(numpy) +// Equivalent to np.digitize. +// @end_compatibility +func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"boundaries": boundaries} + opspec := tf.OpSpec{ + Type: "Bucketize", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes the product along segments of a tensor. // // Read @{$math_ops#segmentation$the section on segmentation} for an explanation of From 0d0a36aa3c6ab72ef91a02f7ab301bafcd3bdf60 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 10:22:04 -0800 Subject: [PATCH 16/38] call InitMain in hvx ops support checker Change: 154736553 --- .../hvx_ops_support_checker/hvx_ops_support_checker_main.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc index e905d4b3224..03b4dcbfa8b 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" @@ -36,6 +37,9 @@ static int ParseFlags(int argc, char* argv[], string* in_graph) { Flag("in_graph", in_graph, "input graph file name"), }; CHECK(Flags::Parse(&argc, argv, flag_list)); + // We need to call this to set up global state for TensorFlow. + port::InitMain(argv[0], &argc, &argv); + string usage = Flags::Usage(argv[0], flag_list); CHECK(!in_graph->empty()) << "in_graph graph can't be empty.\n" << usage; From 24b049877b8a056200f1c0d125d345c0af637aa1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 10:33:33 -0800 Subject: [PATCH 17/38] Update ops-related pbtxt files. Change: 154738165 --- .../core/ops/compat/ops_history.v1.pbtxt | 27 ++++++++++++++++ tensorflow/core/ops/ops.pbtxt | 32 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 8434251f532..439b022036f 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -3592,6 +3592,33 @@ op { } } } +op { + name: "Bucketize" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type: DT_INT32 + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "boundaries" + type: "list(float)" + } +} op { name: "CTCBeamSearchDecoder" input_arg { diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index bd963a52669..cb313a6c278 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -3633,6 +3633,38 @@ op { summary: "Return the reduction indices for computing gradients of s0 op s1 with broadcast." description: "This is typically used by gradient computations for a broadcasting operation." } +op { + name: "Bucketize" + input_arg { + name: "input" + description: "Any shape of Tensor contains with int or float type." + type_attr: "T" + } + output_arg { + name: "output" + description: "Same shape with \'input\', each value of input replaced with bucket index.\n\n@compatibility(numpy)\nEquivalent to np.digitize.\n@end_compatibility" + type: DT_INT32 + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "boundaries" + type: "list(float)" + description: "A sorted list of floats gives the boundary of the buckets." + } + summary: "Bucketizes \'input\' based on \'boundaries\'." + description: "For example, if the inputs are\n boundaries = [0, 10, 100]\n input = [[-5, 10000]\n [150, 10]\n [5, 100]]\n\nthen the output will be\n output = [[0, 3]\n [3, 2]\n [1, 3]]" +} op { name: "CTCBeamSearchDecoder" input_arg { From 03327190420dd5b1c34a5ffdd0000aff40980ed5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 11:02:10 -0800 Subject: [PATCH 18/38] Adding HighwayWrapper for rnn cell that creates a highway skip connection between the cell's input and output activations. Change: 154741723 --- tensorflow/contrib/rnn/__init__.py | 1 + .../rnn/python/kernel_tests/rnn_cell_test.py | 24 ++++++ tensorflow/contrib/rnn/python/ops/rnn_cell.py | 83 +++++++++++++++++++ 3 files changed, 108 insertions(+) diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index ce1ed7f491b..a7448781241 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -43,6 +43,7 @@ See @{$python/contrib.rnn} guide. @@BidirectionalGridLSTMCell @@NASCell @@PhasedLSTMCell +@@HighwayWrapper ### RNNCell wrappers @@AttentionCellWrapper diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 8b40fc068fe..55fd7e7a51b 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -882,6 +882,30 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].c, expected_state_c) self.assertAllClose(res[1].h, expected_state_h) + def testHighwayWrapper(self): + with self.test_session() as sess: + with variable_scope.variable_scope( + "base_cell", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 3]) + base_cell = core_rnn_cell_impl.GRUCell(3) + g, m_new = base_cell(x, m) + with variable_scope.variable_scope( + "hw_cell", initializer=init_ops.constant_initializer(0.5)): + hw_cell = rnn_cell.HighwayWrapper( + core_rnn_cell_impl.GRUCell(3), carry_bias_init=-100.0) + g_res, m_new_res = hw_cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g, g_res, m_new, m_new_res], { + x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.1, 0.1]]) + }) + # As carry_bias_init is very negative, the carry gate is 'open' and the + # transform gate is 'closed'. This means the output equals the input. + self.assertAllClose(res[1], res[0]) + # States are left untouched + self.assertAllClose(res[2], res[3]) + class LayerNormBasicLSTMCellTest(test.TestCase): diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 83e8c2777f6..acba77f0e13 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -1157,6 +1157,89 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell): return new_attns, new_attn_states +class HighwayWrapper(core_rnn_cell.RNNCell): + """RNNCell wrapper that adds highway connection on cell input and output. + + Based on: + R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks", + arXiv preprint arXiv:1505.00387, 2015. + https://arxiv.org/abs/1505.00387 + """ + + def __init__(self, cell, + couple_carry_transform_gates=True, + carry_bias_init=1.0): + """Constructs a `HighwayWrapper` for `cell`. + + Args: + cell: An instance of `RNNCell`. + couple_carry_transform_gates: boolean, should the Carry and Transform gate + be coupled. + carry_bias_init: float, carry gates bias initialization. + """ + self._cell = cell + self._couple_carry_transform_gates = couple_carry_transform_gates + self._carry_bias_init = carry_bias_init + + @property + def state_size(self): + return self._cell.state_size + + @property + def output_size(self): + return self._cell.output_size + + def zero_state(self, batch_size, dtype): + with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): + return self._cell.zero_state(batch_size, dtype) + + def _highway(self, inp, out): + input_size = inp.get_shape().with_rank(2)[1].value + carry_weight = vs.get_variable("carry_w", [input_size, input_size]) + carry_bias = vs.get_variable( + "carry_b", [input_size], + initializer=init_ops.constant_initializer( + self._carry_bias_init)) + carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias)) + if self._couple_carry_transform_gates: + transform = 1 - carry + else: + transform_weight = vs.get_variable("transform_w", + [input_size, input_size]) + transform_bias = vs.get_variable( + "transform_b", [input_size], + initializer=init_ops.constant_initializer( + -self._carry_bias_init)) + transform = math_ops.sigmoid(nn_ops.xw_plus_b(inp, + transform_weight, + transform_bias)) + return inp * carry + out * transform + + def __call__(self, inputs, state, scope=None): + """Run the cell and add its inputs to its outputs. + + Args: + inputs: cell inputs. + state: cell state. + scope: optional cell scope. + + Returns: + Tuple of cell outputs and new state. + + Raises: + TypeError: If cell inputs and outputs have different structure (type). + ValueError: If cell inputs and outputs have different structure (value). + """ + outputs, new_state = self._cell(inputs, state, scope=scope) + nest.assert_same_structure(inputs, outputs) + # Ensure shapes match + def assert_shape_match(inp, out): + inp.get_shape().assert_is_compatible_with(out.get_shape()) + nest.map_structure(assert_shape_match, inputs, outputs) + res_outputs = nest.map_structure(self._highway, inputs, outputs) + return (res_outputs, new_state) + + class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell): """LSTM unit with layer normalization and recurrent dropout. From 9528658bb51d3a4e14ec1014e8fdd8e0076805e8 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo <ebrevdo@google.com> Date: Mon, 1 May 2017 11:15:23 -0800 Subject: [PATCH 19/38] LookupTable ops go to V2, use resource handles. There's some duplicate documentation; yutaka will deprecate the old ops and remove the duplicate docs, update the shape tests to point to the new ops, and modify lookup_ops.py to use the _v2 ops. Change: 154743350 --- tensorflow/core/framework/resource_mgr.cc | 8 + tensorflow/core/framework/resource_mgr.h | 2 + .../core/kernels/lookup_table_init_op.cc | 13 +- tensorflow/core/kernels/lookup_table_op.cc | 51 ++- tensorflow/core/kernels/lookup_table_op.h | 72 ++-- tensorflow/core/kernels/lookup_util.cc | 50 ++- tensorflow/core/ops/data_flow_ops.cc | 312 ++++++++++++++++++ tensorflow/python/ops/hidden_ops.txt | 11 + 8 files changed, 470 insertions(+), 49 deletions(-) diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index 7f9fe084ba4..ab7dd0c5475 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -246,6 +246,14 @@ ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) { return ctx->input(input).flat<ResourceHandle>()(0); } +Status HandleFromInput(OpKernelContext* ctx, StringPiece input, + ResourceHandle* handle) { + const Tensor* tensor; + TF_RETURN_IF_ERROR(ctx->input(input, &tensor)); + *handle = tensor->flat<ResourceHandle>()(0); + return Status::OK(); +} + Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); return ctx->resource_manager()->Delete(p); diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index fe6e09378fd..26a5766569f 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -211,6 +211,8 @@ ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx, // Returns a resource handle from a numbered op input. ResourceHandle HandleFromInput(OpKernelContext* ctx, int input); +Status HandleFromInput(OpKernelContext* ctx, StringPiece input, + ResourceHandle* handle); // Create a resource pointed by a given resource handle. template <typename T> diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc index bde1d0360a9..ada6fe8d950 100644 --- a/tensorflow/core/kernels/lookup_table_init_op.cc +++ b/tensorflow/core/kernels/lookup_table_init_op.cc @@ -367,7 +367,9 @@ class InitializeTableOp : public OpKernel { GetInitializableLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); - DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(), + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), table->value_dtype()}; DataTypeVector expected_outputs = {}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); @@ -408,6 +410,8 @@ class InitializeTableOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU), InitializeTableOp); +REGISTER_KERNEL_BUILDER(Name("InitializeTableV2").Device(DEVICE_CPU), + InitializeTableOp); // Kernel to initialize a lookup table from a text file. // @@ -433,7 +437,9 @@ class InitializeTableFromTextFileOp : public OpKernel { GetInitializableLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); - DataTypeVector expected_inputs = {DT_STRING_REF, DT_STRING}; + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, DT_STRING}; DataTypeVector expected_outputs = {}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); @@ -472,5 +478,8 @@ class InitializeTableFromTextFileOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("InitializeTableFromTextFile").Device(DEVICE_CPU), InitializeTableFromTextFileOp); +REGISTER_KERNEL_BUILDER( + Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU), + InitializeTableFromTextFileOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index 0a065e37d76..11ce2a71dcb 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -624,7 +624,10 @@ class LookupTableFindOp : public OpKernel { OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); - DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(), + // Input 0 could be a STRING_REF or a RESOURCE + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), table->value_dtype()}; DataTypeVector expected_outputs = {table->value_dtype()}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); @@ -647,6 +650,8 @@ class LookupTableFindOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU), LookupTableFindOp); +REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU), + LookupTableFindOp); // Table insert op. class LookupTableInsertOp : public OpKernel { @@ -658,7 +663,9 @@ class LookupTableInsertOp : public OpKernel { OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); - DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(), + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), table->value_dtype()}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); @@ -680,6 +687,8 @@ class LookupTableInsertOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU), LookupTableInsertOp); +REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU), + LookupTableInsertOp); // Op that returns the size of the given table. class LookupTableSizeOp : public OpKernel { @@ -699,6 +708,8 @@ class LookupTableSizeOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU), LookupTableSizeOp); +REGISTER_KERNEL_BUILDER(Name("LookupTableSizeV2").Device(DEVICE_CPU), + LookupTableSizeOp); // Op that outputs tensors of all keys and all values. class LookupTableExportOp : public OpKernel { @@ -716,6 +727,8 @@ class LookupTableExportOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("LookupTableExport").Device(DEVICE_CPU), LookupTableExportOp); +REGISTER_KERNEL_BUILDER(Name("LookupTableExportV2").Device(DEVICE_CPU), + LookupTableExportOp); // Clear the table and insert data. class LookupTableImportOp : public OpKernel { @@ -727,7 +740,9 @@ class LookupTableImportOp : public OpKernel { OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); - DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(), + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), table->value_dtype()}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); @@ -749,6 +764,8 @@ class LookupTableImportOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU), LookupTableImportOp); +REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU), + LookupTableImportOp); // Register the HashTable op with the currently supported key and value types. #define REGISTER_KERNEL(key_dtype, value_dtype) \ @@ -757,6 +774,13 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU), .Device(DEVICE_CPU) \ .TypeConstraint<key_dtype>("key_dtype") \ .TypeConstraint<value_dtype>("value_dtype"), \ + LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \ + value_dtype>) \ + REGISTER_KERNEL_BUILDER( \ + Name("HashTableV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \ value_dtype>) @@ -778,6 +802,13 @@ REGISTER_KERNEL(string, bool); .Device(DEVICE_CPU) \ .TypeConstraint<key_dtype>("key_dtype") \ .TypeConstraint<value_dtype>("value_dtype"), \ + LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \ + key_dtype, value_dtype>) \ + REGISTER_KERNEL_BUILDER( \ + Name("MutableHashTableV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \ key_dtype, value_dtype>) @@ -796,6 +827,13 @@ REGISTER_KERNEL(int64, float); .Device(DEVICE_CPU) \ .TypeConstraint<key_dtype>("key_dtype") \ .TypeConstraint<value_dtype>("value_dtype"), \ + LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \ + key_dtype, value_dtype>) \ + REGISTER_KERNEL_BUILDER( \ + Name("MutableHashTableOfTensorsV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \ key_dtype, value_dtype>) @@ -813,6 +851,13 @@ REGISTER_KERNEL(string, bool); .Device(DEVICE_CPU) \ .TypeConstraint<key_dtype>("key_dtype") \ .TypeConstraint<value_dtype>("value_dtype"), \ + LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \ + key_dtype, value_dtype>) \ + REGISTER_KERNEL_BUILDER( \ + Name("MutableDenseHashTableV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \ key_dtype, value_dtype>) diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h index ae253b4dc96..4cd25a3cc6f 100644 --- a/tensorflow/core/kernels/lookup_table_op.h +++ b/tensorflow/core/kernels/lookup_table_op.h @@ -51,40 +51,52 @@ class LookupTableOp : public OpKernel { // ctx is not owned by this function. void Compute(OpKernelContext* ctx) override { mutex_lock l(mu_); + if (!table_handle_set_) { OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), use_node_name_sharing_)); - auto creator = [ctx, this](lookup::LookupInterface** ret) { - lookup::LookupInterface* container = new Container(ctx, this); - if (!ctx->status().ok()) { - container->Unref(); - return ctx->status(); - } - if (ctx->track_allocations()) { - ctx->record_device_persistent_memory_allocation( - container->MemoryUsed()); - } - *ret = container; - return Status::OK(); - }; - - lookup::LookupInterface* table = nullptr; - OP_REQUIRES_OK( - ctx, cinfo_.resource_manager() - ->template LookupOrCreate<lookup::LookupInterface>( - cinfo_.container(), cinfo_.name(), &table, creator)); - core::ScopedUnref unref_me(table); - - OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes( - *table, DataTypeToEnum<key_dtype>::v(), - DataTypeToEnum<value_dtype>::v(), cinfo_.name())); - - auto h = table_handle_.AccessTensor(ctx)->template flat<string>(); - h(0) = cinfo_.container(); - h(1) = cinfo_.name(); - table_handle_set_ = true; } - ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx)); + + auto creator = [ctx, this](lookup::LookupInterface** ret) { + lookup::LookupInterface* container = new Container(ctx, this); + if (!ctx->status().ok()) { + container->Unref(); + return ctx->status(); + } + if (ctx->track_allocations()) { + ctx->record_device_persistent_memory_allocation( + container->MemoryUsed()); + } + *ret = container; + return Status::OK(); + }; + + lookup::LookupInterface* table = nullptr; + OP_REQUIRES_OK(ctx, + cinfo_.resource_manager() + ->template LookupOrCreate<lookup::LookupInterface>( + cinfo_.container(), cinfo_.name(), &table, creator)); + core::ScopedUnref unref_me(table); + + OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes( + *table, DataTypeToEnum<key_dtype>::v(), + DataTypeToEnum<value_dtype>::v(), cinfo_.name())); + + if (ctx->expected_output_dtype(0) == DT_RESOURCE) { + Tensor* handle; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); + handle->scalar<ResourceHandle>()() = + MakeResourceHandle<lookup::LookupInterface>(ctx, cinfo_.container(), + cinfo_.name()); + } else { + if (!table_handle_set_) { + auto h = table_handle_.AccessTensor(ctx)->template flat<string>(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + } + ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx)); + } + table_handle_set_ = true; } ~LookupTableOp() override { diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc index f87ce0e6b20..d0f269be231 100644 --- a/tensorflow/core/kernels/lookup_util.cc +++ b/tensorflow/core/kernels/lookup_util.cc @@ -49,26 +49,48 @@ Status GetLookupTable(const string& input_name, OpKernelContext* ctx, LookupInterface** table) { string container; string table_handle; - TF_RETURN_IF_ERROR( - GetTableHandle(input_name, ctx, &container, &table_handle)); - return ctx->resource_manager()->Lookup(container, table_handle, table); + DataType handle_dtype; + TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype)); + if (handle_dtype == DT_RESOURCE) { + ResourceHandle handle; + TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle)); + return LookupResource(ctx, handle, table); + } else { + TF_RETURN_IF_ERROR( + GetTableHandle(input_name, ctx, &container, &table_handle)); + return ctx->resource_manager()->Lookup(container, table_handle, table); + } } Status GetInitializableLookupTable(const string& input_name, OpKernelContext* ctx, InitializableLookupTable** table) { - string container; - string table_handle; - TF_RETURN_IF_ERROR( - GetTableHandle(input_name, ctx, &container, &table_handle)); LookupInterface* lookup_table; - TF_RETURN_IF_ERROR( - ctx->resource_manager()->Lookup(container, table_handle, &lookup_table)); - *table = lookup_table->GetInitializableLookupTable(); - if (*table == nullptr) { - lookup_table->Unref(); - return errors::InvalidArgument("Table ", container, " ", table_handle, - " is not initializable"); + DataType handle_dtype; + TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype)); + if (handle_dtype == DT_RESOURCE) { + ResourceHandle handle; + TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle)); + TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &lookup_table)); + *table = lookup_table->GetInitializableLookupTable(); + if (*table == nullptr) { + lookup_table->Unref(); + return errors::InvalidArgument("Table ", handle.container(), " ", + handle.name(), " is not initializable"); + } + } else { + string container; + string table_handle; + TF_RETURN_IF_ERROR( + GetTableHandle(input_name, ctx, &container, &table_handle)); + TF_RETURN_IF_ERROR(ctx->resource_manager()->Lookup(container, table_handle, + &lookup_table)); + *table = lookup_table->GetInitializableLookupTable(); + if (*table == nullptr) { + lookup_table->Unref(); + return errors::InvalidArgument("Table ", container, " ", table_handle, + " is not initializable"); + } } return Status::OK(); } diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index b34dd4ae90b..f82e9d1eb76 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -210,10 +210,29 @@ Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { return Status::OK(); } +Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + for (int i = 1; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); + } + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->Scalar()); + } + return Status::OK(); +} + Status TwoElementOutput(InferenceContext* c) { c->set_output(0, c->Vector(2)); return Status::OK(); } + +Status ScalarOutput(InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); +} } // namespace REGISTER_OP("RandomShuffleQueue") @@ -1881,6 +1900,38 @@ values: Same shape as `keys`. Values found in the table, or `default_values` for missing keys. )doc"); +REGISTER_OP("LookupTableFindV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("default_value: Tout") + .Output("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + // Default value must be scalar or vector. + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }) + .Doc(R"doc( +Looks up keys in a table, outputs the corresponding values. + +The tensor `keys` must of the same type as the keys of the table. +The output `values` is of the type of the table values. + +The scalar `default_value` is the value output for keys not present in the +table. It must also be of the same type as the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Same shape as `keys`. Values found in the table, or `default_values` + for missing keys. +)doc"); + REGISTER_OP("LookupTableInsert") .Input("table_handle: Ref(string)") .Input("keys: Tin") @@ -1893,6 +1944,30 @@ REGISTER_OP("LookupTableInsert") DimensionHandle unused_dim; TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + // TODO(ebrevdo): Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Updates the table to associates keys with values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("LookupTableInsertV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + // TODO: Validate keys and values shape. return Status::OK(); }) @@ -1918,6 +1993,17 @@ table_handle: Handle to the table. size: Scalar that contains number of elements in the table. )doc"); +REGISTER_OP("LookupTableSizeV2") + .Input("table_handle: resource") + .Output("size: int64") + .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs) + .Doc(R"doc( +Computes the number of elements in the given table. + +table_handle: Handle to the table. +size: Scalar that contains number of elements in the table. +)doc"); + REGISTER_OP("LookupTableExport") .Input("table_handle: Ref(string)") .Output("keys: Tkeys") @@ -1945,6 +2031,31 @@ keys: Vector of all keys present in the table. values: Tensor of all values in the table. Indexed in parallel with `keys`. )doc"); +REGISTER_OP("LookupTableExportV2") + .Input("table_handle: resource") + .Output("keys: Tkeys") + .Output("values: Tvalues") + .Attr("Tkeys: type") + .Attr("Tvalues: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + ShapeHandle values = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); + ShapeHandle keys = c->Vector(c->Dim(values, 0)); + c->set_output(0, keys); + c->set_output(1, values); + return Status::OK(); + }) + .Doc(R"doc( +Outputs all keys and values in the table. + +table_handle: Handle to the table. +keys: Vector of all keys present in the table. +values: Tensor of all values in the table. Indexed in parallel with `keys`. +)doc"); + REGISTER_OP("LookupTableImport") .Input("table_handle: Ref(string)") .Input("keys: Tin") @@ -1957,6 +2068,30 @@ REGISTER_OP("LookupTableImport") DimensionHandle unused_dim; TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + // TODO(ebrevdo): Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Replaces the contents of the table with the specified keys and values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("LookupTableImportV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + // TODO: Validate keys and values shape. return Status::OK(); }) @@ -1998,6 +2133,33 @@ key_dtype: Type of the table keys. value_dtype: Type of the table values. )doc"); +REGISTER_OP("HashTableV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates a non-initialized hash table. + +This op creates a hash table, specifying the type of its keys and values. +Before using the table you will have to initialize it. After initialization the +table will be immutable. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + REGISTER_OP("MutableHashTable") .Output("table_handle: Ref(string)") .Attr("container: string = ''") @@ -2025,6 +2187,33 @@ key_dtype: Type of the table keys. value_dtype: Type of the table values. )doc"); +REGISTER_OP("MutableHashTableV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + REGISTER_OP("MutableHashTableOfTensors") .Output("table_handle: Ref(string)") .Attr("container: string = ''") @@ -2051,6 +2240,32 @@ key_dtype: Type of the table keys. value_dtype: Type of the table values. )doc"); +REGISTER_OP("MutableHashTableOfTensorsV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a vector. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + REGISTER_OP("MutableDenseHashTable") .Input("empty_key: key_dtype") .Output("table_handle: Ref(string)") @@ -2088,6 +2303,43 @@ max_load_factor: The maximum ratio between number of entries and number of buckets before growing the table. Must be between 0 and 1. )doc"); +REGISTER_OP("MutableDenseHashTableV2") + .Input("empty_key: key_dtype") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .Attr("initial_num_buckets: int = 131072") // 2^17 + .Attr("max_load_factor: float = 0.8") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table that uses tensors as the backing store. It uses +"open addressing" with quadratic reprobing to resolve collisions. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +empty_key: The key used to represent empty key buckets internally. Must not + be used in insert or lookup operations. +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +value_shape: The shape of each value. +initial_num_buckets: The initial number of hash table buckets. Must be a power + to 2. +max_load_factor: The maximum ratio between number of entries and number of + buckets before growing the table. Must be between 0 and 1. +)doc"); + REGISTER_OP("InitializeTable") .Input("table_handle: Ref(string)") .Input("keys: Tkey") @@ -2113,6 +2365,29 @@ keys: Keys of type Tkey. values: Values of type Tval. )doc"); +REGISTER_OP("InitializeTableV2") + .Input("table_handle: resource") + .Input("keys: Tkey") + .Input("values: Tval") + .Attr("Tkey: type") + .Attr("Tval: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + ShapeHandle keys; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); + TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); + return Status::OK(); + }) + .Doc(R"doc( +Table initializer that takes two tensors for keys and values respectively. + +table_handle: Handle to a table which will be initialized. +keys: Keys of type Tkey. +values: Values of type Tval. +)doc"); + REGISTER_OP("InitializeTableFromTextFile") .Input("table_handle: Ref(string)") .Input("filename: string") @@ -2152,6 +2427,43 @@ vocab_size: Number of elements of the file, use -1 if unknown. delimiter: Delimiter to separate fields in a line. )doc"); +REGISTER_OP("InitializeTableFromTextFileV2") + .Input("table_handle: resource") + .Input("filename: string") + .Attr("key_index: int >= -2") + .Attr("value_index: int >= -2") + .Attr("vocab_size: int >= -1 = -1") + .Attr("delimiter: string = '\t'") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); + return Status::OK(); + }) + .Doc(R"doc( +Initializes a table from a text file. + +It inserts one key-value pair into the table for each line of the file. +The key and value is extracted from the whole line content, elements from the +split line based on `delimiter` or the line number (starting from zero). +Where to extract the key and value from a line is specified by `key_index` and +`value_index`. + +- A value of -1 means use the line number(starting from zero), expects `int64`. +- A value of -2 means use the whole line content, expects `string`. +- A value >= 0 means use the index (starting at zero) of the split line based + on `delimiter`. + +table_handle: Handle to a table which will be initialized. +filename: Filename of a vocabulary text file. +key_index: Column index in a line to get the table `key` values from. +value_index: Column index that represents information of a line to get the table + `value` values from. +vocab_size: Number of elements of the file, use -1 if unknown. +delimiter: Delimiter to separate fields in a line. +)doc"); + REGISTER_OP("GetSessionHandle") .Input("value: T") .Output("handle: string") diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index 9022e1453df..ec02ee3e039 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -63,16 +63,27 @@ GetSessionHandle GetSessionHandleV2 GetSessionTensor HashTable +HashTableV2 InitializeTable +InitializeTableV2 InitializeTableFromTextFile +InitializeTableFromTextFileV2 LookupTableExport +LookupTableExportV2 LookupTableFind +LookupTableFindV2 LookupTableImport +LookupTableImportV2 LookupTableInsert +LookupTableInsertV2 LookupTableSize +LookupTableSizeV2 MutableDenseHashTable +MutableDenseHashTableV2 MutableHashTable +MutableHashTableV2 MutableHashTableOfTensors +MutableHashTableOfTensorsV2 Mutex MutexAcquire MutexRelease From 55d232add8bf71122c9123b9186fa66c32aba03b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 11:19:35 -0800 Subject: [PATCH 20/38] Update ops-related pbtxt files. Change: 154743864 --- .../core/ops/compat/ops_history.v1.pbtxt | 358 +++++++++++++++ tensorflow/core/ops/ops.pbtxt | 426 ++++++++++++++++++ 2 files changed, 784 insertions(+) diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 439b022036f..8856a0faaf4 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -8905,6 +8905,43 @@ op { } is_stateful: true } +op { + name: "HashTableV2" + output_arg { + name: "table_handle" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "use_node_name_sharing" + type: "bool" + default_value { + b: false + } + } + attr { + name: "key_dtype" + type: "type" + } + attr { + name: "value_dtype" + type: "type" + } + is_stateful: true +} op { name: "HistogramSummary" input_arg { @@ -9336,6 +9373,70 @@ op { } } } +op { + name: "InitializeTableFromTextFileV2" + input_arg { + name: "table_handle" + type: DT_RESOURCE + } + input_arg { + name: "filename" + type: DT_STRING + } + attr { + name: "key_index" + type: "int" + has_minimum: true + minimum: -2 + } + attr { + name: "value_index" + type: "int" + has_minimum: true + minimum: -2 + } + attr { + name: "vocab_size" + type: "int" + default_value { + i: -1 + } + has_minimum: true + minimum: -1 + } + attr { + name: "delimiter" + type: "string" + default_value { + s: "\t" + } + } + is_stateful: true +} +op { + name: "InitializeTableV2" + input_arg { + name: "table_handle" + type: DT_RESOURCE + } + input_arg { + name: "keys" + type_attr: "Tkey" + } + input_arg { + name: "values" + type_attr: "Tval" + } + attr { + name: "Tkey" + type: "type" + } + attr { + name: "Tval" + type: "type" + } + is_stateful: true +} op { name: "Inv" input_arg { @@ -10182,6 +10283,30 @@ op { type: "type" } } +op { + name: "LookupTableExportV2" + input_arg { + name: "table_handle" + type: DT_RESOURCE + } + output_arg { + name: "keys" + type_attr: "Tkeys" + } + output_arg { + name: "values" + type_attr: "Tvalues" + } + attr { + name: "Tkeys" + type: "type" + } + attr { + name: "Tvalues" + type: "type" + } + is_stateful: true +} op { name: "LookupTableFind" input_arg { @@ -10210,6 +10335,34 @@ op { type: "type" } } +op { + name: "LookupTableFindV2" + input_arg { + name: "table_handle" + type: DT_RESOURCE + } + input_arg { + name: "keys" + type_attr: "Tin" + } + input_arg { + name: "default_value" + type_attr: "Tout" + } + output_arg { + name: "values" + type_attr: "Tout" + } + attr { + name: "Tin" + type: "type" + } + attr { + name: "Tout" + type: "type" + } + is_stateful: true +} op { name: "LookupTableImport" input_arg { @@ -10234,6 +10387,30 @@ op { type: "type" } } +op { + name: "LookupTableImportV2" + input_arg { + name: "table_handle" + type: DT_RESOURCE + } + input_arg { + name: "keys" + type_attr: "Tin" + } + input_arg { + name: "values" + type_attr: "Tout" + } + attr { + name: "Tin" + type: "type" + } + attr { + name: "Tout" + type: "type" + } + is_stateful: true +} op { name: "LookupTableInsert" input_arg { @@ -10258,6 +10435,30 @@ op { type: "type" } } +op { + name: "LookupTableInsertV2" + input_arg { + name: "table_handle" + type: DT_RESOURCE + } + input_arg { + name: "keys" + type_attr: "Tin" + } + input_arg { + name: "values" + type_attr: "Tout" + } + attr { + name: "Tin" + type: "type" + } + attr { + name: "Tout" + type: "type" + } + is_stateful: true +} op { name: "LookupTableSize" input_arg { @@ -10270,6 +10471,18 @@ op { type: DT_INT64 } } +op { + name: "LookupTableSizeV2" + input_arg { + name: "table_handle" + type: DT_RESOURCE + } + output_arg { + name: "size" + type: DT_INT64 + } + is_stateful: true +} op { name: "LoopCond" input_arg { @@ -12174,6 +12387,69 @@ op { } is_stateful: true } +op { + name: "MutableDenseHashTableV2" + input_arg { + name: "empty_key" + type_attr: "key_dtype" + } + output_arg { + name: "table_handle" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "use_node_name_sharing" + type: "bool" + default_value { + b: false + } + } + attr { + name: "key_dtype" + type: "type" + } + attr { + name: "value_dtype" + type: "type" + } + attr { + name: "value_shape" + type: "shape" + default_value { + shape { + } + } + } + attr { + name: "initial_num_buckets" + type: "int" + default_value { + i: 131072 + } + } + attr { + name: "max_load_factor" + type: "float" + default_value { + f: 0.8 + } + } + is_stateful: true +} op { name: "MutableHashTable" output_arg { @@ -12258,6 +12534,88 @@ op { } is_stateful: true } +op { + name: "MutableHashTableOfTensorsV2" + output_arg { + name: "table_handle" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "use_node_name_sharing" + type: "bool" + default_value { + b: false + } + } + attr { + name: "key_dtype" + type: "type" + } + attr { + name: "value_dtype" + type: "type" + } + attr { + name: "value_shape" + type: "shape" + default_value { + shape { + } + } + } + is_stateful: true +} +op { + name: "MutableHashTableV2" + output_arg { + name: "table_handle" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "use_node_name_sharing" + type: "bool" + default_value { + b: false + } + } + attr { + name: "key_dtype" + type: "type" + } + attr { + name: "value_dtype" + type: "type" + } + is_stateful: true +} op { name: "Neg" input_arg { diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index cb313a6c278..f6eb0cd7a85 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -8745,6 +8745,51 @@ op { description: "This op creates a hash table, specifying the type of its keys and values.\nBefore using the table you will have to initialize it. After initialization the\ntable will be immutable." is_stateful: true } +op { + name: "HashTableV2" + output_arg { + name: "table_handle" + description: "Handle to a table." + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + description: "If non-empty, this table is placed in the given container.\nOtherwise, a default container is used." + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + description: "If non-empty, this table is shared under the given name across\nmultiple sessions." + } + attr { + name: "use_node_name_sharing" + type: "bool" + default_value { + b: false + } + description: "If true and shared_name is empty, the table is shared\nusing the node name." + } + attr { + name: "key_dtype" + type: "type" + description: "Type of the table keys." + } + attr { + name: "value_dtype" + type: "type" + description: "Type of the table values." + } + summary: "Creates a non-initialized hash table." + description: "This op creates a hash table, specifying the type of its keys and values.\nBefore using the table you will have to initialize it. After initialization the\ntable will be immutable." + is_stateful: true +} op { name: "HistogramSummary" input_arg { @@ -9255,6 +9300,82 @@ op { summary: "Initializes a table from a text file." description: "It inserts one key-value pair into the table for each line of the file.\nThe key and value is extracted from the whole line content, elements from the\nsplit line based on `delimiter` or the line number (starting from zero).\nWhere to extract the key and value from a line is specified by `key_index` and\n`value_index`.\n\n- A value of -1 means use the line number(starting from zero), expects `int64`.\n- A value of -2 means use the whole line content, expects `string`.\n- A value >= 0 means use the index (starting at zero) of the split line based\n on `delimiter`." } +op { + name: "InitializeTableFromTextFileV2" + input_arg { + name: "table_handle" + description: "Handle to a table which will be initialized." + type: DT_RESOURCE + } + input_arg { + name: "filename" + description: "Filename of a vocabulary text file." + type: DT_STRING + } + attr { + name: "key_index" + type: "int" + description: "Column index in a line to get the table `key` values from." + has_minimum: true + minimum: -2 + } + attr { + name: "value_index" + type: "int" + description: "Column index that represents information of a line to get the table\n`value` values from." + has_minimum: true + minimum: -2 + } + attr { + name: "vocab_size" + type: "int" + default_value { + i: -1 + } + description: "Number of elements of the file, use -1 if unknown." + has_minimum: true + minimum: -1 + } + attr { + name: "delimiter" + type: "string" + default_value { + s: "\t" + } + description: "Delimiter to separate fields in a line." + } + summary: "Initializes a table from a text file." + description: "It inserts one key-value pair into the table for each line of the file.\nThe key and value is extracted from the whole line content, elements from the\nsplit line based on `delimiter` or the line number (starting from zero).\nWhere to extract the key and value from a line is specified by `key_index` and\n`value_index`.\n\n- A value of -1 means use the line number(starting from zero), expects `int64`.\n- A value of -2 means use the whole line content, expects `string`.\n- A value >= 0 means use the index (starting at zero) of the split line based\n on `delimiter`." + is_stateful: true +} +op { + name: "InitializeTableV2" + input_arg { + name: "table_handle" + description: "Handle to a table which will be initialized." + type: DT_RESOURCE + } + input_arg { + name: "keys" + description: "Keys of type Tkey." + type_attr: "Tkey" + } + input_arg { + name: "values" + description: "Values of type Tval." + type_attr: "Tval" + } + attr { + name: "Tkey" + type: "type" + } + attr { + name: "Tval" + type: "type" + } + summary: "Table initializer that takes two tensors for keys and values respectively." + is_stateful: true +} op { name: "Inv" input_arg { @@ -10067,6 +10188,34 @@ op { } summary: "Outputs all keys and values in the table." } +op { + name: "LookupTableExportV2" + input_arg { + name: "table_handle" + description: "Handle to the table." + type: DT_RESOURCE + } + output_arg { + name: "keys" + description: "Vector of all keys present in the table." + type_attr: "Tkeys" + } + output_arg { + name: "values" + description: "Tensor of all values in the table. Indexed in parallel with `keys`." + type_attr: "Tvalues" + } + attr { + name: "Tkeys" + type: "type" + } + attr { + name: "Tvalues" + type: "type" + } + summary: "Outputs all keys and values in the table." + is_stateful: true +} op { name: "LookupTableFind" input_arg { @@ -10100,6 +10249,39 @@ op { summary: "Looks up keys in a table, outputs the corresponding values." description: "The tensor `keys` must of the same type as the keys of the table.\nThe output `values` is of the type of the table values.\n\nThe scalar `default_value` is the value output for keys not present in the\ntable. It must also be of the same type as the table values." } +op { + name: "LookupTableFindV2" + input_arg { + name: "table_handle" + description: "Handle to the table." + type: DT_RESOURCE + } + input_arg { + name: "keys" + description: "Any shape. Keys to look up." + type_attr: "Tin" + } + input_arg { + name: "default_value" + type_attr: "Tout" + } + output_arg { + name: "values" + description: "Same shape as `keys`. Values found in the table, or `default_values`\nfor missing keys." + type_attr: "Tout" + } + attr { + name: "Tin" + type: "type" + } + attr { + name: "Tout" + type: "type" + } + summary: "Looks up keys in a table, outputs the corresponding values." + description: "The tensor `keys` must of the same type as the keys of the table.\nThe output `values` is of the type of the table values.\n\nThe scalar `default_value` is the value output for keys not present in the\ntable. It must also be of the same type as the table values." + is_stateful: true +} op { name: "LookupTableImport" input_arg { @@ -10129,6 +10311,35 @@ op { summary: "Replaces the contents of the table with the specified keys and values." description: "The tensor `keys` must be of the same type as the keys of the table.\nThe tensor `values` must be of the type of the table values." } +op { + name: "LookupTableImportV2" + input_arg { + name: "table_handle" + description: "Handle to the table." + type: DT_RESOURCE + } + input_arg { + name: "keys" + description: "Any shape. Keys to look up." + type_attr: "Tin" + } + input_arg { + name: "values" + description: "Values to associate with keys." + type_attr: "Tout" + } + attr { + name: "Tin" + type: "type" + } + attr { + name: "Tout" + type: "type" + } + summary: "Replaces the contents of the table with the specified keys and values." + description: "The tensor `keys` must be of the same type as the keys of the table.\nThe tensor `values` must be of the type of the table values." + is_stateful: true +} op { name: "LookupTableInsert" input_arg { @@ -10158,6 +10369,35 @@ op { summary: "Updates the table to associates keys with values." description: "The tensor `keys` must be of the same type as the keys of the table.\nThe tensor `values` must be of the type of the table values." } +op { + name: "LookupTableInsertV2" + input_arg { + name: "table_handle" + description: "Handle to the table." + type: DT_RESOURCE + } + input_arg { + name: "keys" + description: "Any shape. Keys to look up." + type_attr: "Tin" + } + input_arg { + name: "values" + description: "Values to associate with keys." + type_attr: "Tout" + } + attr { + name: "Tin" + type: "type" + } + attr { + name: "Tout" + type: "type" + } + summary: "Updates the table to associates keys with values." + description: "The tensor `keys` must be of the same type as the keys of the table.\nThe tensor `values` must be of the type of the table values." + is_stateful: true +} op { name: "LookupTableSize" input_arg { @@ -10173,6 +10413,21 @@ op { } summary: "Computes the number of elements in the given table." } +op { + name: "LookupTableSizeV2" + input_arg { + name: "table_handle" + description: "Handle to the table." + type: DT_RESOURCE + } + output_arg { + name: "size" + description: "Scalar that contains number of elements in the table." + type: DT_INT64 + } + summary: "Computes the number of elements in the given table." + is_stateful: true +} op { name: "LoopCond" input_arg { @@ -11877,6 +12132,80 @@ op { description: "\"open addressing\" with quadratic reprobing to resolve collisions.\n\nThis op creates a mutable hash table, specifying the type of its keys and\nvalues. Each value must be a scalar. Data can be inserted into the table using\nthe insert operations. It does not support the initialization operation." is_stateful: true } +op { + name: "MutableDenseHashTableV2" + input_arg { + name: "empty_key" + description: "The key used to represent empty key buckets internally. Must not\nbe used in insert or lookup operations." + type_attr: "key_dtype" + } + output_arg { + name: "table_handle" + description: "Handle to a table." + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + description: "If non-empty, this table is placed in the given container.\nOtherwise, a default container is used." + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + description: "If non-empty, this table is shared under the given name across\nmultiple sessions." + } + attr { + name: "use_node_name_sharing" + type: "bool" + default_value { + b: false + } + } + attr { + name: "key_dtype" + type: "type" + description: "Type of the table keys." + } + attr { + name: "value_dtype" + type: "type" + description: "Type of the table values." + } + attr { + name: "value_shape" + type: "shape" + default_value { + shape { + } + } + description: "The shape of each value." + } + attr { + name: "initial_num_buckets" + type: "int" + default_value { + i: 131072 + } + description: "The initial number of hash table buckets. Must be a power\nto 2." + } + attr { + name: "max_load_factor" + type: "float" + default_value { + f: 0.8 + } + description: "The maximum ratio between number of entries and number of\nbuckets before growing the table. Must be between 0 and 1." + } + summary: "Creates an empty hash table that uses tensors as the backing store. It uses" + description: "\"open addressing\" with quadratic reprobing to resolve collisions.\n\nThis op creates a mutable hash table, specifying the type of its keys and\nvalues. Each value must be a scalar. Data can be inserted into the table using\nthe insert operations. It does not support the initialization operation." + is_stateful: true +} op { name: "MutableHashTable" output_arg { @@ -11976,6 +12305,103 @@ op { description: "This op creates a mutable hash table, specifying the type of its keys and\nvalues. Each value must be a vector. Data can be inserted into the table using\nthe insert operations. It does not support the initialization operation." is_stateful: true } +op { + name: "MutableHashTableOfTensorsV2" + output_arg { + name: "table_handle" + description: "Handle to a table." + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + description: "If non-empty, this table is placed in the given container.\nOtherwise, a default container is used." + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + description: "If non-empty, this table is shared under the given name across\nmultiple sessions." + } + attr { + name: "use_node_name_sharing" + type: "bool" + default_value { + b: false + } + } + attr { + name: "key_dtype" + type: "type" + description: "Type of the table keys." + } + attr { + name: "value_dtype" + type: "type" + description: "Type of the table values." + } + attr { + name: "value_shape" + type: "shape" + default_value { + shape { + } + } + } + summary: "Creates an empty hash table." + description: "This op creates a mutable hash table, specifying the type of its keys and\nvalues. Each value must be a vector. Data can be inserted into the table using\nthe insert operations. It does not support the initialization operation." + is_stateful: true +} +op { + name: "MutableHashTableV2" + output_arg { + name: "table_handle" + description: "Handle to a table." + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + description: "If non-empty, this table is placed in the given container.\nOtherwise, a default container is used." + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + description: "If non-empty, this table is shared under the given name across\nmultiple sessions." + } + attr { + name: "use_node_name_sharing" + type: "bool" + default_value { + b: false + } + description: "If true and shared_name is empty, the table is shared\nusing the node name." + } + attr { + name: "key_dtype" + type: "type" + description: "Type of the table keys." + } + attr { + name: "value_dtype" + type: "type" + description: "Type of the table values." + } + summary: "Creates an empty hash table." + description: "This op creates a mutable hash table, specifying the type of its keys and\nvalues. Each value must be a scalar. Data can be inserted into the table using\nthe insert operations. It does not support the initialization operation." + is_stateful: true +} op { name: "Neg" input_arg { From a1efebe6289c7876b343162076fa968532dbbfbe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 11:26:08 -0800 Subject: [PATCH 21/38] Go: Update generated wrapper functions for TensorFlow ops. Change: 154744701 --- tensorflow/go/op/wrappers.go | 501 +++++++++++++++++++++++++++++++++++ 1 file changed, 501 insertions(+) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 001e5f9feb2..4caa75c930f 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -3522,6 +3522,256 @@ func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Opera return scope.AddOperation(opspec) } +// Table initializer that takes two tensors for keys and values respectively. +// +// Arguments: +// table_handle: Handle to a table which will be initialized. +// keys: Keys of type Tkey. +// values: Values of type Tval. +// +// Returns the created operation. +func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InitializeTableV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) +} + +// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. +type MutableHashTableV2Attr func(optionalAttr) + +// MutableHashTableV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableHashTableV2Container(value string) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableHashTableV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// Creates an empty hash table. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a scalar. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableHashTableV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// HashTableV2Attr is an optional argument to HashTableV2. +type HashTableV2Attr func(optionalAttr) + +// HashTableV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func HashTableV2Container(value string) HashTableV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// HashTableV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func HashTableV2SharedName(value string) HashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// Creates a non-initialized hash table. +// +// This op creates a hash table, specifying the type of its keys and values. +// Before using the table you will have to initialize it. After initialization the +// table will be immutable. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "HashTableV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Replaces the contents of the table with the specified keys and values. +// +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. +// +// Returns the created operation. +func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableImportV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) +} + +// Outputs all keys and values in the table. +// +// Arguments: +// table_handle: Handle to the table. +// +// +// +// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. +func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} + opspec := tf.OpSpec{ + Type: "LookupTableExportV2", + Input: []tf.Input{ + table_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Updates the table to associates keys with values. +// +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. +// +// Returns the created operation. +func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableInsertV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) +} + +// Looks up keys in a table, outputs the corresponding values. +// +// The tensor `keys` must of the same type as the keys of the table. +// The output `values` is of the type of the table values. +// +// The scalar `default_value` is the value output for keys not present in the +// table. It must also be of the same type as the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// +// +// Returns Same shape as `keys`. Values found in the table, or `default_values` +// for missing keys. +func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableFindV2", + Input: []tf.Input{ + table_handle, keys, default_value, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // FakeQuantWithMinMaxArgsAttr is an optional argument to FakeQuantWithMinMaxArgs. type FakeQuantWithMinMaxArgsAttr func(optionalAttr) @@ -10019,6 +10269,70 @@ func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max return op.Output(0), op.Output(1), op.Output(2) } +// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2. +type InitializeTableFromTextFileV2Attr func(optionalAttr) + +// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value. +// +// value: Number of elements of the file, use -1 if unknown. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr { + return func(m optionalAttr) { + m["vocab_size"] = value + } +} + +// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value. +// +// value: Delimiter to separate fields in a line. +// If not specified, defaults to "\t" +func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr { + return func(m optionalAttr) { + m["delimiter"] = value + } +} + +// Initializes a table from a text file. +// +// It inserts one key-value pair into the table for each line of the file. +// The key and value is extracted from the whole line content, elements from the +// split line based on `delimiter` or the line number (starting from zero). +// Where to extract the key and value from a line is specified by `key_index` and +// `value_index`. +// +// - A value of -1 means use the line number(starting from zero), expects `int64`. +// - A value of -2 means use the whole line content, expects `string`. +// - A value >= 0 means use the index (starting at zero) of the split line based +// on `delimiter`. +// +// Arguments: +// table_handle: Handle to a table which will be initialized. +// filename: Filename of a vocabulary text file. +// key_index: Column index in a line to get the table `key` values from. +// value_index: Column index that represents information of a line to get the table +// `value` values from. +// +// Returns the created operation. +func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "InitializeTableFromTextFileV2", + Input: []tf.Input{ + table_handle, filename, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + // ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent. type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr) @@ -10875,6 +11189,75 @@ func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } +// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. +type MutableHashTableOfTensorsV2Attr func(optionalAttr) + +// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// If not specified, defaults to false +func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. +// If not specified, defaults to <> +func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["value_shape"] = value + } +} + +// Creates an empty hash table. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a vector. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableHashTableOfTensorsV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. type ResourceApplyProximalAdagradAttr func(optionalAttr) @@ -13270,6 +13653,104 @@ func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) return op.Output(0), op.Output(1) } +// MutableDenseHashTableV2Attr is an optional argument to MutableDenseHashTableV2. +type MutableDenseHashTableV2Attr func(optionalAttr) + +// MutableDenseHashTableV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableDenseHashTableV2Container(value string) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableDenseHashTableV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableDenseHashTableV2SharedName(value string) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableDenseHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// If not specified, defaults to false +func MutableDenseHashTableV2UseNodeNameSharing(value bool) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// MutableDenseHashTableV2ValueShape sets the optional value_shape attribute to value. +// +// value: The shape of each value. +// If not specified, defaults to <> +func MutableDenseHashTableV2ValueShape(value tf.Shape) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["value_shape"] = value + } +} + +// MutableDenseHashTableV2InitialNumBuckets sets the optional initial_num_buckets attribute to value. +// +// value: The initial number of hash table buckets. Must be a power +// to 2. +// If not specified, defaults to 131072 +func MutableDenseHashTableV2InitialNumBuckets(value int64) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["initial_num_buckets"] = value + } +} + +// MutableDenseHashTableV2MaxLoadFactor sets the optional max_load_factor attribute to value. +// +// value: The maximum ratio between number of entries and number of +// buckets before growing the table. Must be between 0 and 1. +// If not specified, defaults to 0.8 +func MutableDenseHashTableV2MaxLoadFactor(value float32) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["max_load_factor"] = value + } +} + +// Creates an empty hash table that uses tensors as the backing store. It uses +// +// "open addressing" with quadratic reprobing to resolve collisions. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a scalar. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// empty_key: The key used to represent empty key buckets internally. Must not +// be used in insert or lookup operations. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.DataType, optional ...MutableDenseHashTableV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableDenseHashTableV2", + Input: []tf.Input{ + empty_key, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // LRNAttr is an optional argument to LRN. type LRNAttr func(optionalAttr) @@ -20554,6 +21035,26 @@ func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.O return op.Output(0), op.Output(1) } +// Computes the number of elements in the given table. +// +// Arguments: +// table_handle: Handle to the table. +// +// Returns Scalar that contains number of elements in the table. +func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableSizeV2", + Input: []tf.Input{ + table_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. type ResizeBilinearGradAttr func(optionalAttr) From bcf7e53d183a97f7139e3e1343eed363b69c2877 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 11:37:14 -0800 Subject: [PATCH 22/38] Small fixes to op generation. Change: 154745991 --- tensorflow/python/framework/op_def_library.py | 2 +- tensorflow/python/framework/python_op_gen.cc | 10 +++++----- tensorflow/python/framework/python_op_gen.h | 7 ++++--- tensorflow/python/framework/python_op_gen.i | 2 +- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index 2c39f5b0e37..662c2c679c8 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -328,7 +328,7 @@ class OpDefLibrary(object): # Need to flatten all the arguments into a list. # pylint: disable=protected-access g = ops._get_graph_from_inputs(_Flatten(keywords.values())) - # pyline: enable=protected-access + # pylint: enable=protected-access except AssertionError as e: raise RuntimeError( "Cannot determine graph for Op '%s' due to: %s" diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 64be2c70a11..d6fd750f6fd 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -629,10 +629,10 @@ static string GetPythonOp(const OpDef& op_def, bool is_hidden, string op_name) { } void GenerateLowerCaseOpName(const string& str, string* result) { - char joiner = '_'; - int last_index = str.size() - 1; + const char joiner = '_'; + const int last_index = str.size() - 1; for (int i = 0; i <= last_index; ++i) { - char c = str[i]; + const char c = str[i]; // Emit a joiner only if a previous-lower-to-now-upper or a // now-upper-to-next-lower transition happens. if (isupper(c) && (i > 0)) { @@ -731,8 +731,8 @@ void PrintPythonOps(const OpList& ops, const std::vector<string>& hidden_ops, printf("%s", GetPythonOps(ops, hidden_ops, require_shapes).c_str()); } -string GetPythonWrappers(const char* op_wrapper_buf, size_t op_wrapper_len) { - string op_list_str(op_wrapper_buf, op_wrapper_len); +string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) { + string op_list_str(op_list_buf, op_list_len); OpList ops; ops.ParseFromString(op_list_str); return GetPythonOps(ops, {}, false); diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h index 424244fcc55..e51eacbc17a 100644 --- a/tensorflow/python/framework/python_op_gen.h +++ b/tensorflow/python/framework/python_op_gen.h @@ -33,9 +33,10 @@ string GetPythonOps(const OpList& ops, const std::vector<string>& hidden_ops, bool require_shapes); // Get the python wrappers for a list of ops in a OpList. -// buf should be a pointer to a buffer containing the binary encoded OpList -// proto, and len should be the length of that buffer. -string GetPythonWrappers(const char* op_wrapper_buf, size_t op_wrapper_len); +// `op_list_buf` should be a pointer to a buffer containing +// the binary encoded OpList proto, and `op_list_len` should be the +// length of that buffer. +string GetPythonWrappers(const char* op_list_buf, size_t op_list_len); } // namespace tensorflow diff --git a/tensorflow/python/framework/python_op_gen.i b/tensorflow/python/framework/python_op_gen.i index 740eff4ecdb..26ec4e8e66b 100644 --- a/tensorflow/python/framework/python_op_gen.i +++ b/tensorflow/python/framework/python_op_gen.i @@ -25,7 +25,7 @@ limitations under the License. // going from python bytes to const char* tries to decode the // contents from utf-8 to unicode for Python version >= 3, but // we want the bytes to be uninterpreted. -%typemap(in) (const char* op_wrapper_buf, size_t op_wrapper_len) { +%typemap(in) (const char* op_list_buf, size_t op_list_len) { char* c_string; Py_ssize_t py_size; if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) { From bb17b9665c189d1349d783219306100204ef2352 Mon Sep 17 00:00:00 2001 From: Zakaria Haque <zakaria@google.com> Date: Mon, 1 May 2017 11:37:53 -0800 Subject: [PATCH 23/38] Sets tag explicitly for training loss so the training and eval losses will show up in the same graph in tensorboard. This was inadvertently changed in cl/153426348 Change: 154746092 --- .../learn/python/learn/estimators/head.py | 5 +- .../python/learn/estimators/head_test.py | 80 +++++++++---------- 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 12af78398b2..452f8a901ee 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import sparse_ops @@ -613,7 +614,9 @@ def _create_model_fn_ops(features, if (mode != model_fn.ModeKeys.INFER) and (labels is not None): weight_tensor = _weight_tensor(features, weight_column_name) loss, weighted_average_loss = loss_fn(labels, logits, weight_tensor) - summary.scalar( + # Uses the deprecated API to set the tag explicitly. + # Without it, trianing and eval losses will show up in different graphs. + logging_ops.scalar_summary( _summary_key(head_name, mkey.LOSS), weighted_average_loss) if mode == model_fn.ModeKeys.TRAIN: diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index e81b15a1725..442530cb5ee 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -123,7 +123,7 @@ class PoissonHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=logits) self._assert_output_alternatives(model_fn_ops) - _assert_summary_tags(self, ["regression_head/loss"]) + _assert_summary_tags(self, ["loss"]) _assert_no_variables(self) loss = self._log_poisson_loss(logits, labels) _assert_metrics(self, loss, {"loss": loss}, model_fn_ops) @@ -149,7 +149,7 @@ class RegressionHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((1.,), (1.,), (3.,))) self._assert_output_alternatives(model_fn_ops) - _assert_summary_tags(self, ["regression_head/loss"]) + _assert_summary_tags(self, ["loss"]) _assert_no_variables(self) _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) @@ -179,7 +179,7 @@ class RegressionHeadTest(test.TestCase): _assert_variables( self, expected_global=w, expected_model=w, expected_trainable=w) variables.global_variables_initializer().run() - _assert_summary_tags(self, ["regression_head/loss"]) + _assert_summary_tags(self, ["loss"]) _assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops) def testRegressionWithLogitsAndLogitsInput(self): @@ -207,7 +207,7 @@ class RegressionHeadTest(test.TestCase): self._assert_output_alternatives(model_fn_ops) self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) - _assert_summary_tags(self, ["regression_head/loss"]) + _assert_summary_tags(self, ["loss"]) _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) def testRegressionWithLabelName(self): @@ -222,7 +222,7 @@ class RegressionHeadTest(test.TestCase): logits=((1.,), (1.,), (3.,))) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["regression_head/loss"]) + _assert_summary_tags(self, ["loss"]) _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) def testRegressionWithWeights(self): @@ -237,7 +237,7 @@ class RegressionHeadTest(test.TestCase): logits=((1.,), (1.,), (3.,))) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["regression_head/loss"]) + _assert_summary_tags(self, ["loss"]) _assert_metrics(self, 2. / len(weights), {"loss": 2. / np.sum(weights)}, model_fn_ops) @@ -260,7 +260,7 @@ class RegressionHeadTest(test.TestCase): expected_trainable=("regression_head/centered_bias_weight:0",)) variables.global_variables_initializer().run() _assert_summary_tags(self, [ - "regression_head/loss", + "loss", "regression_head/centered_bias/bias_0" ]) _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) @@ -331,7 +331,7 @@ class MultiLabelHeadTest(test.TestCase): logits=self._logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_label_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = .89985204 _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -348,7 +348,7 @@ class MultiLabelHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_label_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = 1.00320443 _assert_metrics(self, expected_loss, { "accuracy": 0., @@ -388,7 +388,7 @@ class MultiLabelHeadTest(test.TestCase): _assert_variables( self, expected_global=w, expected_model=w, expected_trainable=w) variables.global_variables_initializer().run() - _assert_summary_tags(self, ["multi_label_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = .69314718 _assert_metrics(self, expected_loss, { "accuracy": 2. / 3, @@ -433,7 +433,7 @@ class MultiLabelHeadTest(test.TestCase): self._assert_output_alternatives(model_fn_ops) self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_label_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = .89985204 _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -452,7 +452,7 @@ class MultiLabelHeadTest(test.TestCase): self._assert_output_alternatives(model_fn_ops) self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_label_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = 1.377779 expected_eval_metrics = { "accuracy": 1. / 3, @@ -520,7 +520,7 @@ class MultiLabelHeadTest(test.TestCase): head_lib.no_op_train_fn, logits=self._logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_label_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = .89985204 _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -540,7 +540,7 @@ class MultiLabelHeadTest(test.TestCase): logits=self._logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_label_head/loss"]) + _assert_summary_tags(self, ["loss"]) _assert_metrics(self, .089985214, self._expected_eval_metrics(.89985214), model_fn_ops) @@ -559,7 +559,7 @@ class MultiLabelHeadTest(test.TestCase): logits=self._logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_label_head/loss"]) + _assert_summary_tags(self, ["loss"]) _assert_metrics(self, .089985214, self._expected_eval_metrics(.89985214), model_fn_ops) @@ -579,7 +579,7 @@ class MultiLabelHeadTest(test.TestCase): logits=self._logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_label_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = .089985214 _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -604,7 +604,7 @@ class MultiLabelHeadTest(test.TestCase): expected_trainable=("multi_label_head/centered_bias_weight:0",)) variables.global_variables_initializer().run() _assert_summary_tags(self, ( - "multi_label_head/loss", + "loss", "multi_label_head/centered_bias/bias_0", "multi_label_head/centered_bias/bias_1", "multi_label_head/centered_bias/bias_2" @@ -629,7 +629,7 @@ class MultiLabelHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=self._logits) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_label_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = .89985204 _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -695,7 +695,7 @@ class BinaryClassificationHeadTest(test.TestCase): logits=self._logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["binary_logistic_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = .81326175 _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -723,7 +723,7 @@ class BinaryClassificationHeadTest(test.TestCase): _assert_variables( self, expected_global=w, expected_model=w, expected_trainable=w) variables.global_variables_initializer().run() - _assert_summary_tags(self, ["binary_logistic_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = .69314718 label_mean = np.mean(self._labels) _assert_metrics(self, expected_loss, { @@ -759,7 +759,7 @@ class BinaryClassificationHeadTest(test.TestCase): self._assert_output_alternatives(model_fn_ops) self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) - _assert_summary_tags(self, ["binary_logistic_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = .81326175 _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -838,7 +838,7 @@ class BinaryClassificationHeadTest(test.TestCase): logits=self._logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["binary_logistic_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = .81326175 _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -859,7 +859,7 @@ class BinaryClassificationHeadTest(test.TestCase): logits=self._logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["binary_logistic_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_total_loss = .31326166 _assert_metrics( self, @@ -892,7 +892,7 @@ class BinaryClassificationHeadTest(test.TestCase): logits=self._logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["binary_logistic_head/loss"]) + _assert_summary_tags(self, ["loss"]) # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) # expected_loss is (total_weighted_loss)/1 since htere is 1 nonzero @@ -932,7 +932,7 @@ class BinaryClassificationHeadTest(test.TestCase): expected_trainable=("binary_logistic_head/centered_bias_weight:0",)) variables.global_variables_initializer().run() _assert_summary_tags(self, [ - "binary_logistic_head/loss", + "loss", "binary_logistic_head/centered_bias/bias_0" ]) expected_loss = .81326175 @@ -983,7 +983,7 @@ class MultiClassHeadTest(test.TestCase): logits=self._logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_class_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = 1.5514447 _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -1022,7 +1022,7 @@ class MultiClassHeadTest(test.TestCase): _assert_variables( self, expected_global=w, expected_model=w, expected_trainable=w) variables.global_variables_initializer().run() - _assert_summary_tags(self, ["multi_class_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = 1.0986123 _assert_metrics(self, expected_loss, { "accuracy": 0., @@ -1073,7 +1073,7 @@ class MultiClassHeadTest(test.TestCase): expected_trainable=("multi_class_head/centered_bias_weight:0",)) variables.global_variables_initializer().run() _assert_summary_tags(self, - ["multi_class_head/loss", + ["loss", "multi_class_head/centered_bias/bias_0", "multi_class_head/centered_bias/bias_1", "multi_class_head/centered_bias/bias_2"]) @@ -1091,7 +1091,7 @@ class MultiClassHeadTest(test.TestCase): self._assert_output_alternatives(model_fn_ops) self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_class_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = 1.5514447 _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -1110,7 +1110,7 @@ class MultiClassHeadTest(test.TestCase): self._assert_output_alternatives(model_fn_ops) self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_class_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = 3.1698461 expected_eval_metrics = { "accuracy": 0., @@ -1149,7 +1149,7 @@ class MultiClassHeadTest(test.TestCase): logits=self._logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_class_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = 1.5514447 _assert_metrics(self, expected_loss * weight, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -1173,7 +1173,7 @@ class MultiClassHeadTest(test.TestCase): logits=self._logits) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_class_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = 1.5514447 * weight _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) @@ -1280,7 +1280,7 @@ class MultiClassHeadTest(test.TestCase): data_flow_ops.tables_initializer().run() self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_class_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = 1.5514447 expected_eval_metrics = { "accuracy": 0., @@ -1306,7 +1306,7 @@ class MultiClassHeadTest(test.TestCase): data_flow_ops.tables_initializer().run() self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) - _assert_summary_tags(self, ["multi_class_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = 0.5514447 expected_eval_metrics = { "accuracy": 1., @@ -1345,7 +1345,7 @@ class BinarySvmHeadTest(test.TestCase): logits=self._predictions) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["binary_svm_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = np.average(self._expected_losses) _assert_metrics(self, expected_loss, { "accuracy": 1., @@ -1375,7 +1375,7 @@ class BinarySvmHeadTest(test.TestCase): _assert_variables( self, expected_global=w, expected_model=w, expected_trainable=w) variables.global_variables_initializer().run() - _assert_summary_tags(self, ["binary_svm_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = 1. _assert_metrics(self, expected_loss, { "accuracy": .5, @@ -1407,7 +1407,7 @@ class BinarySvmHeadTest(test.TestCase): self._assert_output_alternatives(model_fn_ops) self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) - _assert_summary_tags(self, ["binary_svm_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = np.average(self._expected_losses) _assert_metrics(self, expected_loss, { "accuracy": 1., @@ -1426,7 +1426,7 @@ class BinarySvmHeadTest(test.TestCase): logits=self._predictions) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["binary_svm_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_loss = np.average(self._expected_losses) _assert_metrics(self, expected_loss, { "accuracy": 1., @@ -1446,7 +1446,7 @@ class BinarySvmHeadTest(test.TestCase): logits=self._predictions) self._assert_output_alternatives(model_fn_ops) _assert_no_variables(self) - _assert_summary_tags(self, ["binary_svm_head/loss"]) + _assert_summary_tags(self, ["loss"]) expected_weighted_losses = np.multiply(weights, self._expected_losses) _assert_metrics(self, np.mean(expected_weighted_losses), { "accuracy": 1., @@ -1473,7 +1473,7 @@ class BinarySvmHeadTest(test.TestCase): expected_trainable=("binary_svm_head/centered_bias_weight:0",)) variables.global_variables_initializer().run() _assert_summary_tags(self, [ - "binary_svm_head/loss", + "loss", "binary_svm_head/centered_bias/bias_0" ]) expected_loss = np.average(self._expected_losses) From dbdc9fc2ed2e8f9b247c8b2980b0bcd079e39359 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo <ebrevdo@google.com> Date: Mon, 1 May 2017 11:55:26 -0800 Subject: [PATCH 24/38] Add should-use for commonly misused ops. Fixed a bunch of invalid callers (by initially using should_use_with_fatal and looking for failing unit tests) Change: 154748211 --- .../python/ops/quantized_distribution.py | 2 +- .../python/learn/estimators/estimator.py | 3 -- .../kernel_tests/tensor_array_ops_test.py | 2 +- tensorflow/python/ops/control_flow_ops.py | 2 + tensorflow/python/ops/image_ops_impl.py | 38 ++++++++++-------- tensorflow/python/ops/image_ops_test.py | 14 +------ tensorflow/python/ops/resources.py | 2 + tensorflow/python/ops/tensor_array_ops.py | 6 +++ tensorflow/python/ops/variables.py | 7 ++++ tensorflow/python/util/tf_should_use.py | 39 ++++++++++++++++--- tensorflow/python/util/tf_should_use_test.py | 8 ++-- 11 files changed, 81 insertions(+), 42 deletions(-) diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index ac027f0f43e..8aebb79b913 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -232,7 +232,7 @@ class QuantizedDistribution(distributions.Distribution): graph_parents = self._dist._graph_parents # pylint: disable=protected-access checks = [] - if low is not None and high is not None: + if validate_args and low is not None and high is not None: message = "low must be strictly less than high." checks.append( check_ops.assert_less( diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 4ca3120deca..123db50d325 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -1287,9 +1287,6 @@ class Estimator(BaseEstimator): else: saver_for_restore = saver.Saver(sharded=True) with tf_session.Session('') as session: - variables.initialize_local_variables() - data_flow_ops.tables_initializer() - resources.initialize_resources(resources.shared_resources()) saver_for_restore.restore(session, checkpoint_path) init_op = control_flow_ops.group( variables.local_variables_initializer(), diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index 0fec42e1dba..41fe29e006f 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -1128,7 +1128,7 @@ class TensorArrayTest(test.TestCase): dtype=dtypes.float32, size=0, dynamic_size=False, infer_shape=True) self.assertEqual(0, ta.size().eval()) # Don't actually perform the pack. This stores the static shape. - ta.unstack(array_ops.zeros([0, 3, 5])) + ta.unstack(array_ops.zeros([0, 3, 5])).mark_used() packed = ta.stack() self.assertAllEqual([0, 3, 5], packed.eval().shape) # Concatenating zero tensors along their first dimension gives a diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 75a537643e6..ebe5259de5b 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -72,6 +72,7 @@ from tensorflow.python.ops.gen_control_flow_ops import * # pylint: enable=wildcard-import from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util import tf_should_use # We override the 'tuple' for a control flow op, so we keep python's @@ -84,6 +85,7 @@ _basetuple = tuple # Assert and Print are special symbols in python, so we must # use an upper-case version of them. +@tf_should_use.should_use_result def Assert(condition, data, summarize=None, name=None): """Asserts that the given condition is true. diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 78621d3b570..c79f413c5ee 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -218,7 +218,8 @@ def random_flip_up_down(image, seed=None): ValueError: if the shape of `image` not supported. """ image = ops.convert_to_tensor(image, name='image') - _Check3DImage(image, require_static=False) + image = control_flow_ops.with_dependencies( + _Check3DImage(image, require_static=False), image) uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) mirror_cond = math_ops.less(uniform_random, .5) result = control_flow_ops.cond(mirror_cond, @@ -246,7 +247,8 @@ def random_flip_left_right(image, seed=None): ValueError: if the shape of `image` not supported. """ image = ops.convert_to_tensor(image, name='image') - _Check3DImage(image, require_static=False) + image = control_flow_ops.with_dependencies( + _Check3DImage(image, require_static=False), image) uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) mirror_cond = math_ops.less(uniform_random, .5) result = control_flow_ops.cond(mirror_cond, @@ -273,7 +275,8 @@ def flip_left_right(image): ValueError: if the shape of `image` not supported. """ image = ops.convert_to_tensor(image, name='image') - _Check3DImage(image, require_static=False) + image = control_flow_ops.with_dependencies( + _Check3DImage(image, require_static=False), image) return fix_image_flip_shape(image, array_ops.reverse(image, [1])) @@ -295,7 +298,8 @@ def flip_up_down(image): ValueError: if the shape of `image` not supported. """ image = ops.convert_to_tensor(image, name='image') - _Check3DImage(image, require_static=False) + image = control_flow_ops.with_dependencies( + _Check3DImage(image, require_static=False), image) return fix_image_flip_shape(image, array_ops.reverse(image, [0])) @@ -312,7 +316,8 @@ def rot90(image, k=1, name=None): """ with ops.name_scope(name, 'rot90', [image, k]) as scope: image = ops.convert_to_tensor(image, name='image') - _Check3DImage(image, require_static=False) + image = control_flow_ops.with_dependencies( + _Check3DImage(image, require_static=False), image) k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k') k.get_shape().assert_has_rank(0) k = math_ops.mod(k, 4) @@ -350,7 +355,8 @@ def transpose_image(image): ValueError: if the shape of `image` not supported. """ image = ops.convert_to_tensor(image, name='image') - _Check3DImage(image, require_static=False) + image = control_flow_ops.with_dependencies( + _Check3DImage(image, require_static=False), image) return array_ops.transpose(image, [1, 0, 2], name='transpose_image') @@ -379,12 +385,14 @@ def central_crop(image, central_fraction): 3-D float Tensor """ image = ops.convert_to_tensor(image, name='image') - _Check3DImage(image, require_static=False) if central_fraction <= 0.0 or central_fraction > 1.0: raise ValueError('central_fraction must be within (0, 1]') if central_fraction == 1.0: return image + image = control_flow_ops.with_dependencies( + _Check3DImage(image, require_static=False), image) + img_shape = array_ops.shape(image) depth = image.get_shape()[2] fraction_offset = int(1 / ((1 - central_fraction) / 2.0)) @@ -435,9 +443,6 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height, """ image = ops.convert_to_tensor(image, name='image') - assert_ops = [] - assert_ops += _CheckAtLeast3DImage(image, require_static=False) - is_batch = True image_shape = image.get_shape() if image_shape.ndims == 3: @@ -450,6 +455,8 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height, elif image_shape.ndims != 4: raise ValueError('\'image\' must have either 3 or 4 dimensions.') + assert_ops = _CheckAtLeast3DImage(image, require_static=False) + batch, height, width, depth = _ImageDimensions(image, rank=4) after_padding_width = target_width - offset_width - width @@ -515,9 +522,6 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height, """ image = ops.convert_to_tensor(image, name='image') - assert_ops = [] - assert_ops += _CheckAtLeast3DImage(image, require_static=False) - is_batch = True image_shape = image.get_shape() if image_shape.ndims == 3: @@ -530,6 +534,8 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height, elif image_shape.ndims != 4: raise ValueError('\'image\' must have either 3 or 4 dimensions.') + assert_ops = _CheckAtLeast3DImage(image, require_static=False) + batch, height, width, depth = _ImageDimensions(image, rank=4) assert_ops += _assert(offset_width >= 0, ValueError, @@ -602,8 +608,7 @@ def resize_image_with_crop_or_pad(image, target_height, target_width): elif image_shape.ndims != 4: raise ValueError('\'image\' must have either 3 or 4 dimensions.') - assert_ops = [] - assert_ops += _CheckAtLeast3DImage(image, require_static=False) + assert_ops = _CheckAtLeast3DImage(image, require_static=False) assert_ops += _assert(target_width > 0, ValueError, 'target_width must be > 0.') assert_ops += _assert(target_height > 0, ValueError, @@ -800,7 +805,8 @@ def per_image_standardization(image): ValueError: if the shape of 'image' is incompatible with this function. """ image = ops.convert_to_tensor(image, name='image') - _Check3DImage(image, require_static=False) + image = control_flow_ops.with_dependencies( + _Check3DImage(image, require_static=False), image) num_pixels = math_ops.reduce_prod(array_ops.shape(image)) image = math_ops.cast(image, dtype=dtypes.float32) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 716d767b0ce..1a70d465077 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -1175,12 +1175,7 @@ class CropToBoundingBoxTest(test_util.TensorFlowTestCase): offset_height, offset_width = [0, 0] target_height, target_width = [2, 2] - for x_shape in ([3, 5],): - self._assertRaises(x, x_shape, offset_height, offset_width, target_height, - target_width, - "'image' must be at least three-dimensional.") - - for x_shape in ([1, 3, 5, 1, 1],): + for x_shape in ([3, 5], [1, 3, 5, 1, 1]): self._assertRaises(x, x_shape, offset_height, offset_width, target_height, target_width, "'image' must have either 3 or 4 dimensions.") @@ -1426,12 +1421,7 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase): offset_height, offset_width = [0, 0] target_height, target_width = [2, 2] - for x_shape in ([3, 5],): - self._assertRaises(x, x_shape, offset_height, offset_width, target_height, - target_width, - "'image' must be at least three-dimensional") - - for x_shape in ([1, 3, 5, 1, 1],): + for x_shape in ([3, 5], [1, 3, 5, 1, 1]): self._assertRaises(x, x_shape, offset_height, offset_width, target_height, target_width, "'image' must have either 3 or 4 dimensions.") diff --git a/tensorflow/python/ops/resources.py b/tensorflow/python/ops/resources.py index 41fb8a74a9e..57ba0084e84 100644 --- a/tensorflow/python/ops/resources.py +++ b/tensorflow/python/ops/resources.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import tf_should_use _Resource = collections.namedtuple("_Resource", @@ -98,6 +99,7 @@ def report_uninitialized_resources(resource_list=None, return array_ops.boolean_mask(variable_names_tensor, variables_mask) +@tf_should_use.should_use_result def initialize_resources(resource_list, name="init"): """Initializes the resources in the given list. diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index d1013c53ddf..b1c7d74a0cb 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import tf_should_use def _maybe_set_device(handle_op, value_t): @@ -252,6 +253,7 @@ class TensorArray(object): value.set_shape(self._element_shape[0].dims) return value + @tf_should_use.should_use_result def write(self, index, value, name=None): """Write `value` into index `index` of the TensorArray. @@ -358,6 +360,7 @@ class TensorArray(object): value.set_shape([None] + self._element_shape[0].dims[1:]) return value + @tf_should_use.should_use_result def unstack(self, value, name=None): """Unstack the values of a `Tensor` in the TensorArray. @@ -380,6 +383,7 @@ class TensorArray(object): return self.scatter( indices=math_ops.range(0, num_elements), value=value, name=name) + @tf_should_use.should_use_result def scatter(self, indices, value, name=None): """Scatter the values of a `Tensor` in specific indices of a `TensorArray`. @@ -418,6 +422,7 @@ class TensorArray(object): ta._merge_element_shape(element_shape) return ta + @tf_should_use.should_use_result def split(self, value, lengths, name=None): """Split the values of a `Tensor` into the TensorArray. @@ -466,6 +471,7 @@ class TensorArray(object): return gen_data_flow_ops._tensor_array_size_v3( handle=self._handle, flow_in=self.flow, name=name) + @tf_should_use.should_use_result def close(self, name=None): """Close the current TensorArray.""" with ops.colocate_with(self._handle): diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 8b508e45a40..33523f1a717 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.util import compat +from tensorflow.python.util import tf_should_use from tensorflow.python.util.deprecation import deprecated @@ -1152,6 +1153,7 @@ def variables_initializer(var_list, name="init"): return control_flow_ops.no_op(name=name) +@tf_should_use.should_use_result @deprecated("2017-03-02", "Use `tf.variables_initializer` instead.") def initialize_variables(var_list, name="init"): """See `tf.variables_initializer`.""" @@ -1169,6 +1171,7 @@ def global_variables_initializer(): return variables_initializer(global_variables()) +@tf_should_use.should_use_result @deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.") def initialize_all_variables(): """See `tf.global_variables_initializer`.""" @@ -1186,12 +1189,14 @@ def local_variables_initializer(): return variables_initializer(local_variables()) +@tf_should_use.should_use_result @deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.") def initialize_local_variables(): """See `tf.local_variables_initializer`.""" return local_variables_initializer() +@tf_should_use.should_use_result def is_variable_initialized(variable): """Tests if a variable has been initialized. @@ -1205,6 +1210,7 @@ def is_variable_initialized(variable): return state_ops.is_variable_initialized(variable) +@tf_should_use.should_use_result def assert_variables_initialized(var_list=None): """Returns an Op to check if variables are initialized. @@ -1246,6 +1252,7 @@ def assert_variables_initialized(var_list=None): return array_ops.stack(ranks) +@tf_should_use.should_use_result def report_uninitialized_variables(var_list=None, name="report_uninitialized_variables"): """Adds ops to list the names of uninitialized variables. diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index a6a1ad48920..88df3351e66 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -22,6 +22,7 @@ import traceback import types from tensorflow.python.platform import tf_logging +from tensorflow.python.util import tf_decorator def _add_should_use_warning(x, fatal_error=False): @@ -36,6 +37,13 @@ def _add_should_use_warning(x, fatal_error=False): An instance of `TFShouldUseWarningWrapper` which subclasses `type(x)` and is a very shallow wrapper for `x` which logs access into `x`. """ + if x is None: # special corner case where x is None + return x + has_been_used = getattr(x, '_tf_object_has_been_used', None) + if has_been_used is not None: + x._tf_object_has_been_used = has_been_used # pylint: disable=protected-access + return x + def override_method(method): def fn(self, *args, **kwargs): self._tf_object_has_been_used = True # pylint: disable=protected-access @@ -67,18 +75,27 @@ def _add_should_use_warning(x, fatal_error=False): else: logger = tf_logging.error logger( - 'Object was never used: %s.\nIt was originally created here:\n%s' - % (self, self._tf_object_creation_stack)) + '==================================\n' + 'Object was never used (type %s):\n%s\nIf you want to mark it as ' + 'used call its "mark_used()" method.\nIt was originally created ' + 'here:\n%s\n' + '==================================' % + (type(x), x, self._tf_object_creation_stack)) if hasattr(super(TFShouldUseWarningWrapper, self), '__del__'): return super(TFShouldUseWarningWrapper, self).__del__() + + def mark_used(self, *args, **kwargs): + self._tf_object_has_been_used = True + if hasattr(super(TFShouldUseWarningWrapper, self), 'mark_used'): + return super(TFShouldUseWarningWrapper, self).mark_used(*args, **kwargs) # pylint: enable=super-on-old-class for name in dir(TFShouldUseWarningWrapper): method = getattr(TFShouldUseWarningWrapper, name) if not isinstance(method, types.FunctionType): continue - if name in ('__init__', '__getattribute__', '__del__'): + if name in ('__init__', '__getattribute__', '__del__', 'mark_used'): continue setattr(TFShouldUseWarningWrapper, name, functools.wraps(method)(override_method(method))) @@ -114,7 +131,13 @@ def should_use_result(fn): """ def wrapped(*args, **kwargs): return _add_should_use_warning(fn(*args, **kwargs)) - return functools.wraps(fn)(wrapped) + return tf_decorator.make_decorator( + fn, wrapped, 'should_use_result', + ((fn.__doc__ or '') + + ('\n\n ' + '**NOTE** The output of this function should be used. If it is not, ' + 'a warning will be logged. To mark the output as used, ' + 'call its .mark_used() method.'))) def must_use_result_or_fatal(fn): @@ -142,4 +165,10 @@ def must_use_result_or_fatal(fn): """ def wrapped(*args, **kwargs): return _add_should_use_warning(fn(*args, **kwargs), fatal_error=True) - return functools.wraps(fn)(wrapped) + return tf_decorator.make_decorator( + fn, wrapped, 'must_use_result_or_fatal', + ((fn.__doc__ or '') + + ('\n\n ' + '**NOTE** The output of this function must be used. If it is not, ' + 'a fatal error will be raised. To mark the output as used, ' + 'call its .mark_used() method.'))) diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py index 09130eed3a5..71d48e3dde3 100644 --- a/tensorflow/python/util/tf_should_use_test.py +++ b/tensorflow/python/util/tf_should_use_test.py @@ -52,7 +52,7 @@ class TfShouldUseTest(test.TestCase): h = tf_should_use._add_should_use_warning(c) del h in_this_function() - self.assertIn('Object was never used:', '\n'.join(captured)) + self.assertIn('Object was never used', '\n'.join(captured)) self.assertIn('blah:0', '\n'.join(captured)) self.assertIn('in_this_function', '\n'.join(captured)) @@ -63,7 +63,7 @@ class TfShouldUseTest(test.TestCase): h = tf_should_use._add_should_use_warning(c) fn(h) del h - self.assertNotIn('Object was never used:', '\n'.join(captured)) + self.assertNotIn('Object was never used', '\n'.join(captured)) self.assertNotIn('blah:0', '\n'.join(captured)) def testAddShouldUseWarningWhenUsedWithAdd(self): @@ -83,7 +83,7 @@ class TfShouldUseTest(test.TestCase): captured = [] with reroute_error(captured): return_const(0.0) - self.assertIn('Object was never used:', '\n'.join(captured)) + self.assertIn('Object was never used', '\n'.join(captured)) self.assertIn('blah:0', '\n'.join(captured)) self.assertIn('return_const', '\n'.join(captured)) @@ -99,7 +99,7 @@ class TfShouldUseTest(test.TestCase): # unused op as being "used". v = constant_op.constant(1.0, name='meh') v.eval() - self.assertIn('Object was never used:', '\n'.join(captured)) + self.assertIn('Object was never used', '\n'.join(captured)) self.assertIn('blah:0', '\n'.join(captured)) self.assertIn('return_const', '\n'.join(captured)) From 0388bb0e2d7c4ddeef7c51d1a1eafebfbd8f2ed4 Mon Sep 17 00:00:00 2001 From: Olivia Nordquist <nolivia@google.com> Date: Mon, 1 May 2017 12:19:53 -0800 Subject: [PATCH 25/38] removing lesser known latin abbreviation Change: 154751087 --- tensorflow/docs_src/programmers_guide/variables.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/docs_src/programmers_guide/variables.md b/tensorflow/docs_src/programmers_guide/variables.md index 04bfca5f3bd..e8d1e519f0b 100644 --- a/tensorflow/docs_src/programmers_guide/variables.md +++ b/tensorflow/docs_src/programmers_guide/variables.md @@ -62,7 +62,7 @@ with tf.device("/job:ps/task:7"): v = tf.Variable(...) ``` -**N.B.** Operations that mutate a variable, such as +**NOTE** Operations that mutate a variable, such as @{tf.Variable.assign} and the parameter update operations in a @{tf.train.Optimizer} *must* run on From 54d78fdb9461111af31dc2c8d08db833a6402e17 Mon Sep 17 00:00:00 2001 From: Alexandre Passos <apassos@google.com> Date: Mon, 1 May 2017 12:30:30 -0800 Subject: [PATCH 26/38] Fixes latex in math ops' documentation. Change: 154752402 --- tensorflow/core/ops/math_ops.cc | 47 +++++++++++------------ tensorflow/python/ops/special_math_ops.py | 15 ++++---- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index aafbd1b5cce..494358fec69 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -662,13 +662,12 @@ Compute the upper regularized incomplete Gamma function `Q(a, x)`. The upper regularized incomplete Gamma function is defined as: -``` -Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x) -``` +\\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\) + where -``` -Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt -``` + +\\(Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt\\) + is the upper incomplete Gama function. Note, above `P(a, x)` (`Igamma`) is the lower regularized complete @@ -686,13 +685,13 @@ Compute the lower regularized incomplete Gamma function `Q(a, x)`. The lower regularized incomplete Gamma function is defined as: -``` -P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x) -``` + +\\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) + where -``` -gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt -``` + +\\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\) + is the lower incomplete Gamma function. Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete @@ -710,9 +709,9 @@ Compute the Hurwitz zeta function \\(\zeta(x, q)\\). The Hurwitz zeta function is defined as: -``` -\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x} -``` + +\\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\) + )doc"); REGISTER_OP("Polygamma") @@ -726,9 +725,9 @@ Compute the polygamma function \\(\psi^{(n)}(x)\\). The polygamma function is defined as: -``` -\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x) -``` + +\\(\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x)\\) + where \\(\psi(x)\\) is the digamma function. )doc"); @@ -775,14 +774,14 @@ Compute the regularized incomplete beta integral \\(I_x(a, b)\\). The regularized incomplete beta integral is defined as: -``` -I_x(a, b) = \frac{B(x; a, b)}{B(a, b)} -``` + +\\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\) + where -``` -B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt -``` + +\\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\) + is the incomplete beta function and \\(B(a, b)\\) is the *complete* beta function. diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index e24246464ec..851fba0beba 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -35,19 +35,20 @@ from tensorflow.python.platform import tf_logging as logging # TODO(b/27419586) Change docstring for required dtype of x once int allowed def lbeta(x, name='lbeta'): - r"""Computes `ln(|Beta(x)|)`, reducing along the last dimension. + r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension. Given one-dimensional `z = [z_0,...,z_{K-1}]`, we define - ```Beta(z) = \prod_j Gamma(z_j) / Gamma(\sum_j z_j)``` + $$Beta(z) = \prod_j Gamma(z_j) / Gamma(\sum_j z_j)$$ And for `n + 1` dimensional `x` with shape `[N1, ..., Nn, K]`, we define - `lbeta(x)[i1, ..., in] = Log(|Beta(x[i1, ..., in, :])|)`. In other words, - the last dimension is treated as the `z` vector. + $$lbeta(x)[i1, ..., in] = Log(|Beta(x[i1, ..., in, :])|)$$. + + In other words, the last dimension is treated as the `z` vector. Note that if `z = [u, v]`, then - `Beta(z) = int_0^1 t^{u-1} (1 - t)^{v-1} dt`, which defines the traditional - bivariate beta function. + \\(Beta(z) = int_0^1 t^{u-1} (1 - t)^{v-1} dt\\), which defines the + traditional bivariate beta function. If the last dimension is empty, we follow the convention that the sum over the empty set is zero, and the product is one. @@ -57,7 +58,7 @@ def lbeta(x, name='lbeta'): name: A name for the operation (optional). Returns: - The logarithm of `|Beta(x)|` reducing along the last dimension. + The logarithm of \\(|Beta(x)|\\) reducing along the last dimension. """ # In the event that the last dimension has zero entries, we return -inf. # This is consistent with a convention that the sum over the empty set 0, and From 02c0a9651cc9d956af866ca40bad4b6c5f129920 Mon Sep 17 00:00:00 2001 From: Toby Boyd <tobyboyd@google.com> Date: Mon, 1 May 2017 12:30:33 -0800 Subject: [PATCH 27/38] Fixed minor spelling errors and an incorrect link. Change: 154752409 --- .../docs_src/performance/performance_models.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/docs_src/performance/performance_models.md b/tensorflow/docs_src/performance/performance_models.md index 18f407c702e..71c4e6cfe06 100644 --- a/tensorflow/docs_src/performance/performance_models.md +++ b/tensorflow/docs_src/performance/performance_models.md @@ -1,4 +1,4 @@ -# TensorFlow High-Performance Models +# High-Performance Models TensorFlow is a powerful and flexible machine learning platform. It can be used to distribute model training and inference across a large number @@ -29,7 +29,7 @@ We divide our effort to build high-performance models into three categories: ## Input Pipeline -The input pipeline is the part of a tensorflow program that reads input data, +The input pipeline is the part of a TensorFlow program that reads input data, shuffles it, and preprocesses it. Among the most important features to build a fast input pipeline: @@ -149,7 +149,7 @@ stage, and push one set at end end. For example: if there are three stages: A, B and C. There are two staging areas in between: S1 and S2. -During the warmup, we run: +During the warm up, we run: ``` Warm up: @@ -162,14 +162,14 @@ Step 4: A3 B2 C1 Step 5: A4 B3 C2 ``` -After the warmup, S1 and S2 each have one set of data in them. +After the warm up, S1 and S2 each have one set of data in them. For each step of the actual execution, one set of data is consumed from each staging area, and one set is added to each. There are a few nice properties about the scheme: * All the stages are non-blocking, since the staging areas always have one set -of data after the warmup. +of data after the warm up. * Each stage can run in parallel since they can all start immediately. * The staging buffers have a fixed memory overhead. They will have at most one extra set of data. @@ -292,7 +292,7 @@ devices and the fully aggregated gradient is then applied to each local copy. Gradient aggregation across the server can be done in different ways: -* Using standard tensorflow operations to accumulate the total on a single +* Using standard TensorFlow operations to accumulate the total on a single device (CPU or GPU) and then copy it back to all GPUs. * Using NVIDIA NCCL, described below in the NCCL section. @@ -334,7 +334,7 @@ This is available in the benchmark scripts as the 'distributed_replicated' variable_update mode. { +../images/perf_distributed_replicated_mode_doc.png){ width="900" style="max-width: inherit"} #### NCCL From 6155cd1ccfd1c5308efbb3691564790295da47fe Mon Sep 17 00:00:00 2001 From: Olivia Nordquist <nolivia@google.com> Date: Mon, 1 May 2017 12:45:19 -0800 Subject: [PATCH 28/38] Fixes #8666 which removes referral to a deleted function Change: 154754448 --- tensorflow/python/ops/rnn.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 1051478a7f7..f7e98c941a9 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -288,11 +288,10 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, swap_memory=False, time_major=False, scope=None): """Creates a dynamic version of bidirectional recurrent neural network. - Similar to the unidirectional case above (rnn) but takes input and builds - independent forward and backward RNNs. The input_size of forward and - backward cell must match. The initial state for both directions is zero by - default (but can be set optionally) and no intermediate states are ever - returned -- the network is fully unrolled for the given (passed in) + Takes input and builds independent forward and backward RNNs. The input_size + of forward and backward cell must match. The initial state for both directions + is zero by default (but can be set optionally) and no intermediate states are + ever returned -- the network is fully unrolled for the given (passed in) length(s) of the sequence(s) or completely unrolled if length(s) is not given. @@ -419,12 +418,10 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, time_major=False, scope=None): """Creates a recurrent neural network specified by RNNCell `cell`. - This function is functionally identical to the function `rnn` above, but - performs fully dynamic unrolling of `inputs`. + Performs fully dynamic unrolling of `inputs`. - Unlike `rnn`, the input `inputs` is not a Python list of `Tensors`, one for - each frame. Instead, `inputs` may be a single `Tensor` where - the maximum time is either the first or second dimension (see the parameter + `Inputs` may be a single `Tensor` where the maximum time is either the first + or second dimension (see the parameter `time_major`). Alternatively, it may be a (possibly nested) tuple of Tensors, each of them having matching batch and time dimensions. The corresponding output is either a single `Tensor` having the same number @@ -433,7 +430,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, The parameter `sequence_length` is optional and is used to copy-through state and zero-out outputs when past a batch element's sequence length. So it's more - for correctness than performance, unlike in rnn(). + for correctness than performance. Args: cell: An instance of RNNCell. From 0793ebb84b3463c349db25d2c740c65b10108e45 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 12:48:13 -0800 Subject: [PATCH 29/38] Update ops-related pbtxt files. Change: 154754897 --- tensorflow/core/ops/ops.pbtxt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index f6eb0cd7a85..b452ab7ef5c 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -3324,7 +3324,7 @@ op { } } summary: "Compute the regularized incomplete beta integral \\\\(I_x(a, b)\\\\)." - description: "The regularized incomplete beta integral is defined as:\n\n```\nI_x(a, b) = \\frac{B(x; a, b)}{B(a, b)}\n```\nwhere\n\n```\nB(x; a, b) = \\int_0^x t^{a-1} (1 - t)^{b-1} dt\n```\n\nis the incomplete beta function and \\\\(B(a, b)\\\\) is the *complete*\nbeta function." + description: "The regularized incomplete beta integral is defined as:\n\n\n\\\\(I_x(a, b) = \\frac{B(x; a, b)}{B(a, b)}\\\\)\n\nwhere\n\n\n\\\\(B(x; a, b) = \\int_0^x t^{a-1} (1 - t)^{b-1} dt\\\\)\n\n\nis the incomplete beta function and \\\\(B(a, b)\\\\) is the *complete*\nbeta function." } op { name: "BiasAdd" @@ -9031,7 +9031,7 @@ op { } } summary: "Compute the lower regularized incomplete Gamma function `Q(a, x)`." - description: "The lower regularized incomplete Gamma function is defined as:\n\n```\nP(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\n```\nwhere\n```\ngamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\n```\nis the lower incomplete Gamma function.\n\nNote, above `Q(a, x)` (`Igammac`) is the upper regularized complete\nGamma function." + description: "The lower regularized incomplete Gamma function is defined as:\n\n\n\\\\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\\\)\n\nwhere\n\n\\\\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\\\)\n\nis the lower incomplete Gamma function.\n\nNote, above `Q(a, x)` (`Igammac`) is the upper regularized complete\nGamma function." } op { name: "Igammac" @@ -9058,7 +9058,7 @@ op { } } summary: "Compute the upper regularized incomplete Gamma function `Q(a, x)`." - description: "The upper regularized incomplete Gamma function is defined as:\n\n```\nQ(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\n```\nwhere\n```\nGamma(a, x) = int_{x}^{\\infty} t^{a-1} exp(-t) dt\n```\nis the upper incomplete Gama function.\n\nNote, above `P(a, x)` (`Igamma`) is the lower regularized complete\nGamma function." + description: "The upper regularized incomplete Gamma function is defined as:\n\n\\\\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\\\)\n\nwhere\n\n\\\\(Gamma(a, x) = int_{x}^{\\infty} t^{a-1} exp(-t) dt\\\\)\n\nis the upper incomplete Gama function.\n\nNote, above `P(a, x)` (`Igamma`) is the lower regularized complete\nGamma function." } op { name: "Imag" @@ -13342,7 +13342,7 @@ op { } } summary: "Compute the polygamma function \\\\(\\psi^{(n)}(x)\\\\)." - description: "The polygamma function is defined as:\n\n```\n\\psi^{(n)}(x) = \\frac{d^n}{dx^n} \\psi(x)\n```\nwhere \\\\(\\psi(x)\\\\) is the digamma function." + description: "The polygamma function is defined as:\n\n\n\\\\(\\psi^{(n)}(x) = \\frac{d^n}{dx^n} \\psi(x)\\\\)\n\nwhere \\\\(\\psi(x)\\\\) is the digamma function." } op { name: "Pow" @@ -26985,5 +26985,5 @@ op { } } summary: "Compute the Hurwitz zeta function \\\\(\\zeta(x, q)\\\\)." - description: "The Hurwitz zeta function is defined as:\n\n```\n\\zeta(x, q) = \\sum_{n=0}^{\\infty} (q + n)^{-x}\n```" + description: "The Hurwitz zeta function is defined as:\n\n\n\\\\(\\zeta(x, q) = \\sum_{n=0}^{\\infty} (q + n)^{-x}\\\\)" } From 46c3f9b0af6f223008dfd9dd720905ff9fccfdd8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 12:53:19 -0800 Subject: [PATCH 30/38] Go: Update generated wrapper functions for TensorFlow ops. Change: 154755628 --- tensorflow/go/op/wrappers.go | 46 +++++++++++++++++------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 4caa75c930f..c63be8bc5ee 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -14316,9 +14316,9 @@ func SoftsignGrad(scope *Scope, gradients tf.Output, features tf.Output) (backpr // // The polygamma function is defined as: // -// ``` -// \psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x) -// ``` +// +// \\(\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x)\\) +// // where \\(\psi(x)\\) is the digamma function. func Polygamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { if scope.Err() != nil { @@ -15540,9 +15540,8 @@ func Any(scope *Scope, input tf.Output, reduction_indices tf.Output, optional .. // // The Hurwitz zeta function is defined as: // -// ``` -// \zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x} -// ``` +// +// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\) func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) { if scope.Err() != nil { return @@ -18720,13 +18719,12 @@ func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { // // The upper regularized incomplete Gamma function is defined as: // -// ``` -// Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x) -// ``` +// \\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\) +// // where -// ``` -// Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt -// ``` +// +// \\(Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt\\) +// // is the upper incomplete Gama function. // // Note, above `P(a, x)` (`Igamma`) is the lower regularized complete @@ -18749,13 +18747,13 @@ func Igammac(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { // // The lower regularized incomplete Gamma function is defined as: // -// ``` -// P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x) -// ``` +// +// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) +// // where -// ``` -// gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt -// ``` +// +// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\) +// // is the lower incomplete Gamma function. // // Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete @@ -18778,14 +18776,14 @@ func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { // // The regularized incomplete beta integral is defined as: // -// ``` -// I_x(a, b) = \frac{B(x; a, b)}{B(a, b)} -// ``` +// +// \\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\) +// // where // -// ``` -// B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt -// ``` +// +// \\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\) +// // // is the incomplete beta function and \\(B(a, b)\\) is the *complete* // beta function. From a44385dff5c27acd0251a324bc3d5064a25140a2 Mon Sep 17 00:00:00 2001 From: Olivia Nordquist <nolivia@google.com> Date: Mon, 1 May 2017 13:01:30 -0800 Subject: [PATCH 31/38] clarify tf.cond documentation Change: 154756751 --- tensorflow/python/ops/control_flow_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index ebe5259de5b..a22e4819d51 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1680,7 +1680,7 @@ def _UnpackIfSingleton(res): def cond(pred, fn1, fn2, strict=False, name=None): - """Return either `fn1()` or `fn2()` based on the boolean predicate `pred`. + """Return `fn1()` if the boolean predicate `pred` is true else `fn2()`. `fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have the same non-zero number and type of outputs. From 0ab09d2dd09f66bf654934b935806147ef23d535 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 13:17:37 -0800 Subject: [PATCH 32/38] Adds access to expected output type from OpOutputList. Change: 154759042 --- tensorflow/core/framework/op_kernel.h | 7 +++++ tensorflow/core/framework/op_kernel_test.cc | 30 +++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 48bb69cb4e4..d926d7db190 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -432,6 +432,7 @@ class OpOutputList { OpOutputList& operator=(const OpOutputList& other) = default; Tensor* operator[](int i); bool required(int i) const; + DataType expected_output_dtype(int i) const; Status allocate(int i, const TensorShape& shape, Tensor** output); void set(int i, const Tensor& tensor); void set_ref(int i, mutex* mu, Tensor* tensor_for_ref); @@ -1452,6 +1453,12 @@ inline bool OpOutputList::required(int i) const { return ctx_->output_required(start_ + i); } +inline DataType OpOutputList::expected_output_dtype(int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->expected_output_dtype(start_ + i); +} + inline Status OpOutputList::allocate(int i, const TensorShape& shape, Tensor** output) { DCHECK_GE(i, 0); diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index 1c561899159..46d4dbd86ac 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -613,6 +613,36 @@ TEST_F(OpKernelBuilderTest, BadConstraint) { error::INVALID_ARGUMENT); } +REGISTER_OP("ListOut").Output("a: int32").Output("b: T").Attr("T: list(type)"); +REGISTER_KERNEL_BUILDER(Name("ListOut").Device(tensorflow::DEVICE_CPU), + DummyKernel); + +TEST_F(OpKernelBuilderTest, OpOutputList) { + Env* env = Env::Default(); + OpKernelContext::Params params; + params.record_tensor_accesses = false; + std::unique_ptr<DummyDevice> device( + new DummyDevice(env, params.record_tensor_accesses)); + params.device = device.get(); + Status status; + std::unique_ptr<OpKernel> op(CreateOpKernel( + DEVICE_CPU, params.device, cpu_allocator(), + CreateNodeDef("ListOut", {"T|list(type)|[DT_FLOAT, DT_INT32]"}), + TF_GRAPH_DEF_VERSION, &status)); + EXPECT_TRUE(status.ok()) << status.ToString(); + params.op_kernel = op.get(); + gtl::InlinedVector<TensorValue, 4> inputs{}; + params.inputs = &inputs; + std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(¶ms)); + + EXPECT_EQ(DT_INT32, ctx->expected_output_dtype(0)); + OpOutputList out_list; + EXPECT_FALSE(ctx->output_list("non_existent_output", &out_list).ok()); + ASSERT_TRUE(ctx->output_list("b", &out_list).ok()); + EXPECT_EQ(DT_FLOAT, out_list.expected_output_dtype(0)); + EXPECT_EQ(DT_INT32, out_list.expected_output_dtype(1)); +} + class GetAttrKernel : public ::tensorflow::OpKernel { public: explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) { From ecbbcc43fa4bc2c3303c44fdc824c98d69b3e774 Mon Sep 17 00:00:00 2001 From: Toby Boyd <tobyboyd@google.com> Date: Mon, 1 May 2017 13:29:30 -0800 Subject: [PATCH 33/38] First draft of benchmark results. Change: 154760723 --- tensorflow/docs_src/performance/benchmarks.md | 429 ++++++++++++++++++ 1 file changed, 429 insertions(+) create mode 100644 tensorflow/docs_src/performance/benchmarks.md diff --git a/tensorflow/docs_src/performance/benchmarks.md b/tensorflow/docs_src/performance/benchmarks.md new file mode 100644 index 00000000000..8c0cff138de --- /dev/null +++ b/tensorflow/docs_src/performance/benchmarks.md @@ -0,0 +1,429 @@ +# TensorFlow Performance Benchmarks + +## Overview + +A selection of image classification models were tested across multiple platforms +to create a point of reference for the TensorFlow community. The methodology, +links to the scripts, and commands to reproduce the results are in the +[appendix](#appendix). + +## Results for image classification models + +InceptionV3 ([arXiv:1512.00567](https://arxiv.org/abs/1512.00567)), +ResNet-50 ([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)), +ResNet-152 ([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)), VGG16 +([arXiv:1409.1556](https://arxiv.org/abs/1409.1556)), and +[AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf) +were tested using the [ImageNet](http://www.image-net.org/) data set. Tests were +run on Google Compute Engine, Amazon Elastic Compute Cloud (Amazon EC2), and an +NVIDIA® DGX-1™. Most of the tests were run with both synthetic and real data. +Testing with synthetic data was done by using a `tf.Variable` set to the same +shape as the data expected by each model for ImageNet. We believe it is +important to include real data measurements when benchmarking a platform. This +load tests both the underlying hardware and the framework at preparing data for +actual training. We start with synthetic data to remove disk I/O as a variable +and to set a baseline. Real data is then used to verify that the TensorFlow +input pipeline and the underlying disk I/O are saturating the compute units. + +### Training with NVIDIA® DGX-1™ (NVIDIA® Tesla® P100) + +<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:100%" src="../images/perf_summary_p100_single_server.png"> +</div> + +Details and additional results are in the [Details for NVIDIA® DGX-1™ (NVIDIA® +Tesla® P100)](#details-for-nvidia®-dgx-1™-nvidia®-tesla®-p100) section. + +### Training with NVIDIA® Tesla® K80 + +<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:100%" src="../images/perf_summary_k80_single_server.png"> +</div> + +Details and additional results are in the [Details for Google Compute Engine +(NVIDIA® Tesla® K80)](#details-for-google-compute-engine-nvidia®-tesla®-k80) and +[Details for Amazon EC2 (NVIDIA® Tesla® +K80)](#details-for-amazon-ec2-nvidia®-tesla®-k80) sections. + +### Distributed training with NVIDIA® Tesla® K80 + +<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:100%" src="../images/perf_summary_k80_aws_distributed.png"> +</div> + +Details and additional results are in the [Details for Amazon EC2 Distributed +(NVIDIA® Tesla® K80)](#details-for-amazon-ec2-distributed-nvidia®-tesla®-k80) +section. + +### Compare synthetic with real data training + +**NVIDIA® Tesla® P100** + +<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:35%" src="../images/perf_summary_p100_data_compare_inceptionv3.png"> + <img style="width:35%" src="../images/perf_summary_p100_data_compare_resnet50.png"> +</div> + +**NVIDIA® Tesla® K80** + +<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:35%" src="../images/perf_summary_k80_data_compare_inceptionv3.png"> + <img style="width:35%" src="../images/perf_summary_k80_data_compare_resnet50.png"> +</div> + +## Details for NVIDIA® DGX-1™ (NVIDIA® Tesla® P100) + +### Environment + +* **Instance type**: NVIDIA® DGX-1™ +* **GPU:** 8x NVIDIA® Tesla® P100 +* **OS:** Ubuntu 16.04 LTS with tests run via Docker +* **CUDA / cuDNN:** 8.0 / 5.1 +* **TensorFlow GitHub hash:** b1e174e +* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda + //tensorflow/tools/pip_package:build_pip_package` +* **Disk:** local SSD +* **DataSet:** ImageNet + +Batch size and optimizer used for each model. + + | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 +------------------ | ----------- | --------- | ---------- | ------- | ----- +Batch size per GPU | 64 | 64 | 64 | 512 | 64 +Optimizer | sgd | sgd | sgd | sgd | sgd + +Configuration used for each model. + +Model | variable_update | local_parameter_device +----------- | ---------------------- | ---------------------- +InceptionV3 | parameter_server | cpu +ResNet50 | parameter_server | cpu +ResNet152 | parameter_server | cpu +AlexNet | replicated (with NCCL) | n/a +VGG16 | replicated (with NCCL) | n/a + +### Results + +Batch size and optimizer used for each model are listed in the table below. + +<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:100%" src="../images/perf_summary_p100_single_server.png"> +</div> + +<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:35%" src="../images/perf_dgx1_synth_p100_single_server_scaling.png"> + <img style="width:35%" src="../images/perf_dgx1_real_p100_single_server_scaling.png"> +</div> + +**Training synthetic data** + +GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 +---- | ----------- | --------- | ---------- | ------- | ----- +1 | 142 | 238 | 95.6 | 2987 | 132 +2 | 284 | 479 | 187 | 5658 | 259 +4 | 569 | 948 | 374 | 10509 | 511 +8 | 1131 | 1886 | 744 | 17822 | 959 + +**Training real data** + +GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 +---- | ----------- | --------- | ---------- | ------- | ----- +1 | 142 | 239 | 95.5 | 2890 | 132 +2 | 278 | 468 | 187 | 4448 | 245 +4 | 551 | 938 | 373 | 7105 | 466 +8 | 1079 | 1802 | 721 | N/A | 794 + +Training AlexNet with real data on 8 GPUs was excluded from the graph and table +above due to it maxing out the input pipeline. + +## Details for Google Compute Engine (NVIDIA® Tesla® K80) + +### Environment + +* **Instance type**: n1-standard-32-k80x8 +* **GPU:** 8x NVIDIA® Tesla® K80 +* **OS:** Ubuntu 16.04 LTS +* **CUDA / cuDNN:** 8.0 / 5.1 +* **TensorFlow GitHub hash:** b1e174e +* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda + //tensorflow/tools/pip_package:build_pip_package` +* **Disk:** 1.7 TB Shared SSD persistent disk (800 MB/s) +* **DataSet:** ImageNet +* **Test Date:** April 2017 + +Batch size and optimizer used for each model are listed in the table below. In +addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were +tested with a batch size of 32. Those results are in the *other results* +section. + + | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 +------------------ | ----------- | --------- | ---------- | ------- | ----- +Batch size per GPU | 64 | 64 | 32 | 512 | 32 +Optimizer | sgd | sgd | sgd | sgd | sgd + +The configuration used for each model was `variable_update` equal to +`parameter_server` and `local_parameter_device` equal to `cpu`. + +### Results + +<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:35%" src="../images/perf_gce_synth_k80_single_server_scaling.png"> + <img style="width:35%" src="../images/perf_gce_real_k80_single_server_scaling.png"> +</div> + +**Training synthetic data** + +GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 +---- | ----------- | --------- | ---------- | ------- | ----- +1 | 30.5 | 56.8 | 20.8 | 656 | 30.3 +2 | 57.8 | 107 | 39.1 | 1210 | 56.2 +4 | 116 | 212 | 77.2 | 2330 | 106 +8 | 227 | 419 | 151 | 4640 | 222 + +**Training real data** + +GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 +---- | ----------- | --------- | ---------- | ------- | ----- +1 | 30.5 | 56.7 | 20.7 | 639 | 30.2 +2 | 57.8 | 107 | 39 | 1136 | 55.5 +4 | 115 | 211 | 77.3 | 2067 | 106 +8 | 225 | 418 | 150 | 4056 | 213 + +### Other Results + +**Training synthetic data** + +GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) +---- | --------------------------- | ------------------------- +1 | 29.3 | 53.9 +2 | 55.0 | 101 +4 | 109 | 200 +8 | 216 | 398 + +**Training real data** + +GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) +---- | --------------------------- | ------------------------- +1 | 29.3 | 53.6 +2 | 55 | 102 +4 | 109 | 200 +8 | 215 | 387 + +## Details for Amazon EC2 (NVIDIA® Tesla® K80) + +### Environment + +* **Instance type**: p2.8xlarge +* **GPU:** 8x NVIDIA® Tesla® K80 +* **OS:** Ubuntu 16.04 LTS +* **CUDA / cuDNN:** 8.0 / 5.1 +* **TensorFlow GitHub hash:** b1e174e +* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda + //tensorflow/tools/pip_package:build_pip_package` +* **Disk:** 1TB Amazon EFS (burst 100 MiB/sec for 12 hours, continuous 50 + MiB/sec) +* **DataSet:** ImageNet +* **Test Date:** April 2017 + +Batch size and optimizer used for each model are listed in the table below. In +addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were +tested with a batch size of 32. Those results are in the *other results* +section. + + | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 +------------------ | ----------- | --------- | ---------- | ------- | ----- +Batch size per GPU | 64 | 64 | 32 | 512 | 32 +Optimizer | sgd | sgd | sgd | sgd | sgd + +Configuration used for each model. + +Model | variable_update | local_parameter_device +----------- | ------------------------- | ---------------------- +InceptionV3 | parameter_server | cpu +ResNet-50 | replicated (without NCCL) | gpu +ResNet-152 | replicated (without NCCL) | gpu +AlexNet | parameter_server | gpu +VGG16 | parameter_server | gpu + +### Results + +<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:35%" src="../images/perf_aws_synth_k80_single_server_scaling.png"> + <img style="width:35%" src="../images/perf_aws_real_k80_single_server_scaling.png"> +</div> + +**Training synthetic data** + +GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 +---- | ----------- | --------- | ---------- | ------- | ----- +1 | 30.8 | 56.3 | 20.9 | 684 | 32.4 +2 | 58.7 | 108 | 39.3 | 1244 | 61.5 +4 | 117 | 217 | 79.1 | 2479 | 123 +8 | 230 | 419 | 156 | 4853 | 234 + +**Training real data** + +GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 +---- | ----------- | --------- | ---------- | ------- | ----- +1 | 30.5 | 56.0 | 20.6 | 674 | 32.0 +2 | 58.7 | 107 | 39.0 | 1227 | 61.0 +4 | 118 | 205 | 77.9 | 2201 | 120 +8 | 228 | 405 | 152 | N/A | 191 + +Training AlexNet with real data on 8 GPUs was excluded from the graph and table +above due to our EFS setup not providing enough throughput. + +### Other Results + +**Training synthetic data** + +GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) +---- | --------------------------- | ------------------------- +1 | 29.9 | 53.5 +2 | 57.5 | 101 +4 | 114 | 202 +8 | 216 | 380 + +**Training real data** + +GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) +---- | --------------------------- | ------------------------- +1 | 30.0 | 53.6 +2 | 57.5 | 101 +4 | 113 | 202 +8 | 212 | 379 + +## Details for Amazon EC2 Distributed (NVIDIA® Tesla® K80) + +### Environment + +* **Instance type**: p2.8xlarge +* **GPU:** 8x NVIDIA® Tesla® K80 +* **OS:** Ubuntu 16.04 LTS +* **CUDA / cuDNN:** 8.0 / 5.1 +* **TensorFlow GitHub hash:** b1e174e +* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda + //tensorflow/tools/pip_package:build_pip_package` +* **Disk:** 1.0 TB EFS (burst 100 MB/sec for 12 hours, continuous 50 MB/sec) +* **DataSet:** ImageNet +* **Test Date:** April 2017 + +The batch size and optimizer used for the tests are listed in the table. In +addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were +tested with a batch size of 32. Those results are in the *other results* +section. + + | InceptionV3 | ResNet-50 | ResNet-152 +------------------ | ----------- | --------- | ---------- +Batch size per GPU | 64 | 64 | 32 +Optimizer | sgd | sgd | sgd + +Configuration used for each model. + +Model | variable_update | local_parameter_device +----------- | ---------------------- | ---------------------- +InceptionV3 | distributed_replicated | n/a +ResNet-50 | distributed_replicated | n/a +ResNet-152 | distributed_replicated | n/a + +To simplify server setup, EC2 instances (p2.8xlarge) running worker servers also +ran parameter servers. Equal numbers of parameter servers and work servers were +used with the following exceptions: + +* InceptionV3: 8 instances / 6 parameter servers +* ResNet-50: (batch size 32) 8 instances / 4 parameter servers +* ResNet-152: 8 instances / 4 parameter servers + +### Results + +<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:95%" src="../images/perf_summary_k80_aws_distributed.png"> +</div> + +<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:70%" src="../images/perf_aws_synth_k80_distributed_scaling.png"> +</div> + +**Training synthetic data** + +GPUs | InceptionV3 | ResNet-50 | ResNet-152 +---- | ----------- | --------- | ---------- +1 | 29.7 | 55.0 | 19.8 +8 | 229 | 410 | 150 +16 | 459 | 825 | 300 +32 | 902 | 1468 | 575 +64 | 1783 | 3051 | 1004 + +### Other Results + +<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:50%" src="../images/perf_aws_synth_k80_multi_server_batch32.png"> +</div> + +**Training synthetic data** + +GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) +---- | --------------------------- | ------------------------- +1 | 29.2 | 53.0 +8 | 219 | 363 +16 | 427 | 719 +32 | 820 | 1265 +64 | 1608 | 2623 + +## Appendix + +### Executing benchmark tests + +The code for the benchmarks was created to be both used for benchmarking +TensorFlow as well as used as a tool to test hardware platforms. The benchmark +code includes modes such as `trivial` that run a virtually empty model that is +useful for testing the maximum possibly samples/sec for the input pipeline among +other things. Not only does this test TensorFlow but also the throughput of the +underlying systems. There are two ways to execute the benchmarks in +[tf_cnn_benchmarks.py](TODO: LINK TO GITHUB): + +1. Execute [tf_cnn_benchmarks.py](TODO: LINK TO GITHUB) directly +2. Utilize the [small wrapper](TODO: LINK TO GITHUB) that helps pick the + correct config + +The wrapper is suggested as a starting point. Then investigate the variety of +options available in `tf_cnn_benchmarks.py`. While the wrapper extensive +examples, below are a couple highlights. + +Run ResNet-50 on a single instance with 8 GPUs. The `system` argument is used to +determine the optimal configuration. The supported values are gce, aws, and +dgx1. If `system` is not passeed, the best config for the most widely available +hardware is used. + +```bash +python main.py --model=resnet50 --num_gpus=8 +python main.py --system=aws --model=resnet50 --num_gpus=8 +``` + +Run ResNet-50 on 2 hosts, e.g. host_0 (10.0.0.1) and host_1 (10.0.0.2), with 8 +GPUs each on aws. + +```bash +# Run the following commands on host_0 (10.0.0.1): + $ python main.py --system=aws --model=resnet50 --job_name=worker + --hosts=10.0.0.1,10.0.0.2 --task_index=0 + + $ python main.py --system=aws --model=resnet50 --job_name=ps + --hosts=10.0.0.1,10.0.0.2 --task_index=0 + +# Run the following commands on host_1 (10.0.0.2): + $ python main.py --system=aws --model=resnet50 --job_name=worker + --hosts=10.0.0.1,10.0.0.2 --task_index=1 + + $ python main.py --system=aws --model=resnet50 --job_name=ps + --hosts=10.0.0.1,10.0.0.2 --task_index=1 +``` + +### Methodology + +Unless otherwise stated, each test is run 5 times and then the times are +averaged together. GPUs are run in their default state on the given platform. +For NVIDIA® Tesla® K80 this means leaving on [GPU +Boost](https://devblogs.nvidia.com/parallelforall/increase-performance-gpu-boost-k80-autoboost/) +unless it has been turned off by the provider. For a given test, 10 warmup steps +are done and then the next 100 steps are averaged. From 68c514faa9470d2bd8aed797339f048c50ed6317 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo <ebrevdo@google.com> Date: Mon, 1 May 2017 13:44:57 -0800 Subject: [PATCH 34/38] [tf distributions] Move kullback_leibler, special_math, and Normal into TF core. Also, renamed distributions.kl to distributions.kl_divergence. In core, scipy can't be required for unit tests, so added some lazy importing and checking for scipy modules in the normal and special_math tests. Change: 154763133 --- .../python/kernel_tests/entropy_test.py | 4 +- .../kernel_tests/stochastic_tensor_test.py | 2 +- .../variational_inference_test.py | 10 +-- .../bayesflow/python/ops/entropy_impl.py | 5 +- .../python/ops/variational_inference_impl.py | 4 +- tensorflow/contrib/distributions/BUILD | 46 ----------- tensorflow/contrib/distributions/__init__.py | 6 +- .../python/kernel_tests/bernoulli_test.py | 4 +- .../python/kernel_tests/beta_test.py | 6 +- .../python/kernel_tests/categorical_test.py | 6 +- .../python/kernel_tests/gamma_test.py | 4 +- .../mvn_diag_plus_low_rank_test.py | 18 +++-- .../python/kernel_tests/mvn_tril_test.py | 8 +- .../kernel_tests/onehot_categorical_test.py | 6 +- .../distributions/python/ops/bernoulli.py | 2 +- .../contrib/distributions/python/ops/beta.py | 2 +- .../distributions/python/ops/categorical.py | 2 +- .../contrib/distributions/python/ops/gamma.py | 2 +- .../distributions/python/ops/laplace.py | 2 +- .../python/ops/mvn_linear_operator.py | 4 +- .../python/ops/normal_conjugate_posteriors.py | 11 ++- .../python/ops/onehot_categorical.py | 2 +- .../python/kernel_tests/distributions/BUILD | 46 +++++++++++ .../distributions}/kullback_leibler_test.py | 30 +++---- .../distributions}/normal_test.py | 80 +++++++++++++------ .../distributions}/special_math_test.py | 33 +++++++- .../ops/distributions}/kullback_leibler.py | 32 ++++---- .../ops/distributions}/normal.py | 4 +- .../ops/distributions}/special_math.py | 0 29 files changed, 223 insertions(+), 158 deletions(-) rename tensorflow/{contrib/distributions/python/kernel_tests => python/kernel_tests/distributions}/kullback_leibler_test.py (76%) rename tensorflow/{contrib/distributions/python/kernel_tests => python/kernel_tests/distributions}/normal_test.py (96%) rename tensorflow/{contrib/distributions/python/kernel_tests => python/kernel_tests/distributions}/special_math_test.py (95%) rename tensorflow/{contrib/distributions/python/ops => python/ops/distributions}/kullback_leibler.py (82%) rename tensorflow/{contrib/distributions/python/ops => python/ops/distributions}/normal.py (98%) rename tensorflow/{contrib/distributions/python/ops => python/ops/distributions}/special_math.py (100%) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py index d98d4e737c3..6cdaa318705 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py @@ -95,7 +95,7 @@ class ElboRatioTest(test.TestCase): n=n_samples, form=entropy.ELBOForms.sample, seed=42) - actual_kl = distributions.kl(q, p) + actual_kl = distributions.kl_divergence(q, p) # Relative tolerance (rtol) chosen 2 times as large as minimim needed to # pass. @@ -123,7 +123,7 @@ class ElboRatioTest(test.TestCase): n=n_samples, form=entropy.ELBOForms.analytic_entropy, seed=42) - actual_kl = distributions.kl(q, p) + actual_kl = distributions.kl_divergence(q, p) # Relative tolerance (rtol) chosen 2 times as large as minimim needed to # pass. diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py index 859f0e4b17c..6d0cff46789 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py @@ -22,11 +22,11 @@ import numpy as np from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor_impl -from tensorflow.contrib.distributions.python.ops import normal from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import normal from tensorflow.python.platform import test sge = stochastic_gradient_estimators diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py index a46d755897e..fff6b74b2ef 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py @@ -22,12 +22,12 @@ from tensorflow.contrib import distributions as distributions_lib from tensorflow.contrib import layers from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor from tensorflow.contrib.bayesflow.python.ops import variational_inference_impl -from tensorflow.contrib.distributions.python.ops import kullback_leibler -from tensorflow.contrib.distributions.python.ops import normal from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables +from tensorflow.python.ops.distributions import kullback_leibler +from tensorflow.python.ops.distributions import normal from tensorflow.python.platform import test st = stochastic_tensor @@ -68,7 +68,7 @@ class VariationalInferenceTest(test.TestCase): def testDefaultVariationalAndPrior(self): _, prior, variational, _, log_likelihood = mini_vae() elbo = vi.elbo(log_likelihood) - expected_elbo = log_likelihood - kullback_leibler.kl( + expected_elbo = log_likelihood - kullback_leibler.kl_divergence( variational.distribution, prior) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) @@ -80,7 +80,7 @@ class VariationalInferenceTest(test.TestCase): prior = normal.Normal(loc=3., scale=2.) elbo = vi.elbo( log_likelihood, variational_with_prior={variational: prior}) - expected_elbo = log_likelihood - kullback_leibler.kl( + expected_elbo = log_likelihood - kullback_leibler.kl_divergence( variational.distribution, prior) sess.run(variables.global_variables_initializer()) self.assertAllEqual(*sess.run([expected_elbo, elbo])) @@ -121,7 +121,7 @@ class VariationalInferenceTest(test.TestCase): # No analytic KL available between prior and variational distributions. with self.assertRaisesRegexp(NotImplementedError, "No KL"): - distributions.kl(variational.distribution, prior) + distributions.kl_divergence(variational.distribution, prior) elbo = vi.elbo( variational_with_prior={variational: prior}, diff --git a/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py b/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py index ef9fb730258..f155de5032b 100644 --- a/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py @@ -84,8 +84,9 @@ def elbo_ratio(log_p, KL[q || p] = E[ Log[q(Z)] - Log[p(Z)] ] ``` - Note that if `p` is a `Distribution`, then `distributions.kl(q, p)` may be - defined and available as an exact result. + Note that if `p` is a `Distribution`, then + `distributions.kl_divergence(q, p)` may be defined and available as an + exact result. #### ELBO diff --git a/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py b/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py index 3d3533461ce..8d932a7c340 100644 --- a/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py @@ -28,10 +28,10 @@ from __future__ import print_function from tensorflow.contrib.bayesflow.python.ops import stochastic_graph_impl as sg from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor_impl as st -from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.platform import tf_logging as logging VI_PRIORS = "__vi_priors__" @@ -259,7 +259,7 @@ def _elbo(form, log_likelihood, log_joint, variational_with_prior, kl = None if log_joint is None and form in {ELBOForms.default, ELBOForms.analytic_kl}: try: - kl = kullback_leibler.kl(q, p) + kl = kullback_leibler.kl_divergence(q, p) logging.info("Using analytic KL between q:%s, p:%s", q, p) except NotImplementedError as e: if form == ELBOForms.analytic_kl: diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index b58bdc75918..1b9bd6ad91c 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -496,24 +496,6 @@ cuda_py_test( ], ) -cuda_py_test( - name = "normal_test", - size = "medium", - srcs = ["python/kernel_tests/normal_test.py"], - additional_deps = [ - ":distributions_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:nn_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:variables", - ], -) - cuda_py_test( name = "poisson_test", size = "small", @@ -614,18 +596,6 @@ cuda_py_test( ], ) -cuda_py_test( - name = "kullback_leibler_test", - size = "small", - srcs = ["python/kernel_tests/kullback_leibler_test.py"], - additional_deps = [ - ":distributions_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:platform_test", - ], -) - cuda_py_test( name = "normal_conjugate_posteriors_test", size = "small", @@ -751,22 +721,6 @@ cuda_py_test( tags = ["no_pip"], ) -cuda_py_test( - name = "special_math_test", - size = "medium", - srcs = ["python/kernel_tests/special_math_test.py"], - additional_deps = [ - ":distributions_py", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:platform_test", - "//tensorflow/python:variables", - ], -) - cuda_py_test( name = "distribution_util_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 6d1a0a4390d..15e33c2c6f0 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -71,7 +71,7 @@ See the @{$python/contrib.distributions} guide. @@RelaxedOneHotCategorical ## Kullback-Leibler Divergence -@@kl +@@kl_divergence @@RegisterKL ## Helper Functions @@ -106,7 +106,6 @@ from tensorflow.contrib.distributions.python.ops.exponential import * from tensorflow.contrib.distributions.python.ops.gamma import * from tensorflow.contrib.distributions.python.ops.geometric import * from tensorflow.contrib.distributions.python.ops.inverse_gamma import * -from tensorflow.contrib.distributions.python.ops.kullback_leibler import * from tensorflow.contrib.distributions.python.ops.laplace import * from tensorflow.contrib.distributions.python.ops.logistic import * from tensorflow.contrib.distributions.python.ops.mixture import * @@ -115,7 +114,6 @@ from tensorflow.contrib.distributions.python.ops.mvn_diag import * from tensorflow.contrib.distributions.python.ops.mvn_diag_plus_low_rank import * from tensorflow.contrib.distributions.python.ops.mvn_tril import * from tensorflow.contrib.distributions.python.ops.negative_binomial import * -from tensorflow.contrib.distributions.python.ops.normal import * from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import * from tensorflow.contrib.distributions.python.ops.onehot_categorical import * from tensorflow.contrib.distributions.python.ops.poisson import * @@ -129,6 +127,8 @@ from tensorflow.contrib.distributions.python.ops.uniform import * from tensorflow.contrib.distributions.python.ops.wishart import * from tensorflow.python.ops.distributions.conditional_distribution import * from tensorflow.python.ops.distributions.distribution import * +from tensorflow.python.ops.distributions.kullback_leibler import * +from tensorflow.python.ops.distributions.normal import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py index 87b2331a1d4..e8b0eb4eb86 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py @@ -21,11 +21,11 @@ from __future__ import print_function import numpy as np import scipy.special from tensorflow.contrib.distributions.python.ops import bernoulli -from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.platform import test @@ -286,7 +286,7 @@ class BernoulliTest(test.TestCase): a = bernoulli.Bernoulli(probs=a_p) b = bernoulli.Bernoulli(probs=b_p) - kl = kullback_leibler.kl(a, b) + kl = kullback_leibler.kl_divergence(a, b) kl_val = sess.run(kl) kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py index f524986cec8..ec16a85991d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py @@ -20,13 +20,13 @@ import numpy as np from scipy import special from scipy import stats from tensorflow.contrib.distributions.python.ops import beta as beta_lib -from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.platform import test @@ -349,13 +349,13 @@ class BetaTest(test.TestCase): for dist1 in [d1, d1_sp]: for dist2 in [d2, d2_sp]: - kl = kullback_leibler.kl(dist1, dist2) + kl = kullback_leibler.kl_divergence(dist1, dist2) kl_val = sess.run(kl) self.assertEqual(kl.get_shape(), shape) self.assertAllClose(kl_val, kl_expected) # Make sure KL(d1||d1) is 0 - kl_same = sess.run(kullback_leibler.kl(d1, d1)) + kl_same = sess.run(kullback_leibler.kl_divergence(d1, d1)) self.assertAllClose(kl_same, np.zeros_like(kl_expected)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/categorical_test.py b/tensorflow/contrib/distributions/python/kernel_tests/categorical_test.py index 0b42581e79f..269c02ede3a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/categorical_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/categorical_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.distributions.python.ops import categorical -from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util @@ -30,6 +29,7 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.platform import test @@ -278,10 +278,10 @@ class CategoricalTest(test.TestCase): a = categorical.Categorical(logits=a_logits) b = categorical.Categorical(logits=b_logits) - kl = kullback_leibler.kl(a, b) + kl = kullback_leibler.kl_divergence(a, b) kl_val = sess.run(kl) # Make sure KL(a||a) is 0 - kl_same = sess.run(kullback_leibler.kl(a, a)) + kl_same = sess.run(kullback_leibler.kl_divergence(a, a)) prob_a = np_softmax(a_logits) prob_b = np_softmax(b_logits) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py index fd627102372..5ccf2308a51 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py @@ -22,12 +22,12 @@ from scipy import special from scipy import stats from tensorflow.contrib.distributions.python.ops import gamma as gamma_lib -from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.platform import test @@ -345,7 +345,7 @@ class GammaTest(test.TestCase): g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1) x = g0.sample(int(1e4), seed=0) kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0) - kl_actual = kullback_leibler.kl(g0, g1) + kl_actual = kullback_leibler.kl_divergence(g0, g1) # Execute graph. [kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py index 29cb3eb9b0d..e803830a4e3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py @@ -185,19 +185,19 @@ class MultivariateNormalDiagPlusLowRankTest(test.TestCase): sample_kl_identity = math_ops.reduce_mean( dist.log_prob(samps) - mvn_identity.log_prob(samps), 0) - analytical_kl_identity = ds.kl(dist, mvn_identity) + analytical_kl_identity = ds.kl_divergence(dist, mvn_identity) sample_kl_scaled = math_ops.reduce_mean( dist.log_prob(samps) - mvn_scaled.log_prob(samps), 0) - analytical_kl_scaled = ds.kl(dist, mvn_scaled) + analytical_kl_scaled = ds.kl_divergence(dist, mvn_scaled) sample_kl_diag = math_ops.reduce_mean( dist.log_prob(samps) - mvn_diag.log_prob(samps), 0) - analytical_kl_diag = ds.kl(dist, mvn_diag) + analytical_kl_diag = ds.kl_divergence(dist, mvn_diag) sample_kl_chol = math_ops.reduce_mean( dist.log_prob(samps) - mvn_chol.log_prob(samps), 0) - analytical_kl_chol = ds.kl(dist, mvn_chol) + analytical_kl_chol = ds.kl_divergence(dist, mvn_chol) n = int(10e3) baseline = ds.MultivariateNormalDiag( @@ -208,19 +208,21 @@ class MultivariateNormalDiagPlusLowRankTest(test.TestCase): sample_kl_identity_diag_baseline = math_ops.reduce_mean( baseline.log_prob(samps) - mvn_identity.log_prob(samps), 0) - analytical_kl_identity_diag_baseline = ds.kl(baseline, mvn_identity) + analytical_kl_identity_diag_baseline = ds.kl_divergence( + baseline, mvn_identity) sample_kl_scaled_diag_baseline = math_ops.reduce_mean( baseline.log_prob(samps) - mvn_scaled.log_prob(samps), 0) - analytical_kl_scaled_diag_baseline = ds.kl(baseline, mvn_scaled) + analytical_kl_scaled_diag_baseline = ds.kl_divergence( + baseline, mvn_scaled) sample_kl_diag_diag_baseline = math_ops.reduce_mean( baseline.log_prob(samps) - mvn_diag.log_prob(samps), 0) - analytical_kl_diag_diag_baseline = ds.kl(baseline, mvn_diag) + analytical_kl_diag_diag_baseline = ds.kl_divergence(baseline, mvn_diag) sample_kl_chol_diag_baseline = math_ops.reduce_mean( baseline.log_prob(samps) - mvn_chol.log_prob(samps), 0) - analytical_kl_chol_diag_baseline = ds.kl(baseline, mvn_chol) + analytical_kl_chol_diag_baseline = ds.kl_divergence(baseline, mvn_chol) [ sample_mean_, diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py index dd7283bb57f..11f632f3f44 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py @@ -241,7 +241,7 @@ class MultivariateNormalTriLTest(test.TestCase): scale_tril=np.linalg.cholesky(sigma_b), validate_args=True) - kl = ds.kl(mvn_a, mvn_b) + kl = ds.kl_divergence(mvn_a, mvn_b) self.assertEqual(batch_shape, kl.get_shape()) kl_v = kl.eval() @@ -263,7 +263,7 @@ class MultivariateNormalTriLTest(test.TestCase): scale_tril=np.linalg.cholesky(sigma_b), validate_args=True) - kl = ds.kl(mvn_a, mvn_b) + kl = ds.kl_divergence(mvn_a, mvn_b) self.assertEqual(batch_shape, kl.get_shape()) kl_v = kl.eval() @@ -285,7 +285,7 @@ class MultivariateNormalTriLTest(test.TestCase): validate_args=True) # Should be zero since KL(p || p) = =. - kl = ds.kl(mvn_a, mvn_a) + kl = ds.kl_divergence(mvn_a, mvn_a) self.assertEqual(batch_shape, kl.get_shape()) kl_v = kl.eval() @@ -323,7 +323,7 @@ class MultivariateNormalTriLTest(test.TestCase): sample_kl_chol = math_ops.reduce_mean( dist.log_prob(samps) - mvn_chol.log_prob(samps), 0) - analytical_kl_chol = ds.kl(dist, mvn_chol) + analytical_kl_chol = ds.kl_divergence(dist, mvn_chol) scale = dist.scale.to_dense() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/onehot_categorical_test.py b/tensorflow/contrib/distributions/python/kernel_tests/onehot_categorical_test.py index 56ad4a081bc..111f88eeb50 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/onehot_categorical_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/onehot_categorical_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.contrib.distributions.python.ops import onehot_categorical from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -27,6 +26,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.platform import test @@ -178,8 +178,8 @@ class OneHotCategoricalTest(test.TestCase): kl_expected = np.sum( prob_p * (np.log(prob_p) - np.log(prob_q)), axis=-1) - kl_actual = kullback_leibler.kl(p, q) - kl_same = kullback_leibler.kl(p, p) + kl_actual = kullback_leibler.kl_divergence(p, q) + kl_same = kullback_leibler.kl_divergence(p, p) x = p.sample(int(2e4), seed=0) x = math_ops.cast(x, dtype=dtypes.float32) # Compute empirical KL(p||q). diff --git a/tensorflow/contrib/distributions/python/ops/bernoulli.py b/tensorflow/contrib/distributions/python/ops/bernoulli.py index 7c29b7cdb2a..3281b57e83e 100644 --- a/tensorflow/contrib/distributions/python/ops/bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/bernoulli.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -29,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/beta.py b/tensorflow/contrib/distributions/python/ops/beta.py index 2866a7d9ff2..2b93478cdf9 100644 --- a/tensorflow/contrib/distributions/python/ops/beta.py +++ b/tensorflow/contrib/distributions/python/ops/beta.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/categorical.py b/tensorflow/contrib/distributions/python/ops/categorical.py index 5848c067690..1b74c2f0ca7 100644 --- a/tensorflow/contrib/distributions/python/ops/categorical.py +++ b/tensorflow/contrib/distributions/python/ops/categorical.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,6 +27,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/gamma.py b/tensorflow/contrib/distributions/python/ops/gamma.py index 255e8026e87..4ac2b9b4ef8 100644 --- a/tensorflow/contrib/distributions/python/ops/gamma.py +++ b/tensorflow/contrib/distributions/python/ops/gamma.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/laplace.py b/tensorflow/contrib/distributions/python/ops/laplace.py index 0a1e3990c4c..5c964ff78a5 100644 --- a/tensorflow/contrib/distributions/python/ops/laplace.py +++ b/tensorflow/contrib/distributions/python/ops/laplace.py @@ -22,7 +22,6 @@ import math import numpy as np -from tensorflow.contrib.distributions.python.ops import special_math from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -33,6 +32,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import special_math __all__ = [ diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 5e879b4b223..a66eb1674ca 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -20,8 +20,6 @@ from __future__ import print_function from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops import kullback_leibler -from tensorflow.contrib.distributions.python.ops import normal from tensorflow.contrib.distributions.python.ops import transformed_distribution from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -29,6 +27,8 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import kullback_leibler +from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py b/tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py index 10e934326a1..4025285780b 100644 --- a/tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py +++ b/tensorflow/contrib/distributions/python/ops/normal_conjugate_posteriors.py @@ -18,9 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distributions.python.ops.normal import Normal # pylint: disable=line-too-long - from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import normal def normal_conjugates_known_scale_posterior(prior, scale, s, n): @@ -65,7 +64,7 @@ def normal_conjugates_known_scale_posterior(prior, scale, s, n): TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a Normal object. """ - if not isinstance(prior, Normal): + if not isinstance(prior, normal.Normal): raise TypeError("Expected prior to be an instance of type Normal") if s.dtype != prior.dtype: @@ -77,7 +76,7 @@ def normal_conjugates_known_scale_posterior(prior, scale, s, n): scale0_2 = math_ops.square(prior.scale) scale_2 = math_ops.square(scale) scalep_2 = 1.0/(1/scale0_2 + n/scale_2) - return Normal( + return normal.Normal( loc=(prior.loc/scale0_2 + s/scale_2) * scalep_2, scale=math_ops.sqrt(scalep_2)) @@ -131,7 +130,7 @@ def normal_conjugates_known_scale_predictive(prior, scale, s, n): TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a Normal object. """ - if not isinstance(prior, Normal): + if not isinstance(prior, normal.Normal): raise TypeError("Expected prior to be an instance of type Normal") if s.dtype != prior.dtype: @@ -143,6 +142,6 @@ def normal_conjugates_known_scale_predictive(prior, scale, s, n): scale0_2 = math_ops.square(prior.scale) scale_2 = math_ops.square(scale) scalep_2 = 1.0/(1/scale0_2 + n/scale_2) - return Normal( + return normal.Normal( loc=(prior.loc/scale0_2 + s/scale_2) * scalep_2, scale=math_ops.sqrt(scalep_2 + scale_2)) diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index 378ddc491f3..b76cebf79fa 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distributions.python.ops import kullback_leibler from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -28,6 +27,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD index 63e213b5e07..3c1a4d5125c 100644 --- a/tensorflow/python/kernel_tests/distributions/BUILD +++ b/tensorflow/python/kernel_tests/distributions/BUILD @@ -29,6 +29,52 @@ cuda_py_test( ], ) +cuda_py_test( + name = "kullback_leibler_test", + size = "small", + srcs = ["kullback_leibler_test.py"], + additional_deps = [ + "//tensorflow/python/ops/distributions", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "normal_test", + size = "medium", + srcs = ["normal_test.py"], + additional_deps = [ + "//tensorflow/python/ops/distributions", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + +cuda_py_test( + name = "special_math_test", + size = "medium", + srcs = ["special_math_test.py"], + additional_deps = [ + "//tensorflow/python/ops/distributions", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py b/tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py similarity index 76% rename from tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py rename to tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py index 6b3d886e01b..b1d8da77161 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py +++ b/tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distributions.python.ops import kullback_leibler -from tensorflow.contrib.distributions.python.ops import normal from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import kullback_leibler +from tensorflow.python.ops.distributions import normal from tensorflow.python.platform import test # pylint: disable=protected-access @@ -43,7 +43,7 @@ class KLTest(test.TestCase): return name a = MyDist(loc=0.0, scale=1.0) - self.assertEqual("OK", kullback_leibler.kl(a, a, name="OK")) + self.assertEqual("OK", kullback_leibler.kl_divergence(a, a, name="OK")) def testDomainErrorExceptions(self): @@ -60,11 +60,11 @@ class KLTest(test.TestCase): with self.test_session(): a = MyDistException(loc=0.0, scale=1.0) - kl = kullback_leibler.kl(a, a, allow_nan_stats=False) + kl = kullback_leibler.kl_divergence(a, a, allow_nan_stats=False) with self.assertRaisesOpError( "KL calculation between .* and .* returned NaN values"): kl.eval() - kl_ok = kullback_leibler.kl(a, a) + kl_ok = kullback_leibler.kl_divergence(a, a) self.assertAllEqual([float("nan")], kl_ok.eval()) def testRegistrationFailures(self): @@ -116,16 +116,16 @@ class KLTest(test.TestCase): sub2 = Sub2(loc=0.0, scale=1.0) sub11 = Sub11(loc=0.0, scale=1.0) - self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub1)) - self.assertEqual("sub1-2", kullback_leibler.kl(sub1, sub2)) - self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub1)) - self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub11)) - self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1)) - self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2)) - self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1)) - self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2)) - self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub11)) - self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub11)) + self.assertEqual("sub1-1", kullback_leibler.kl_divergence(sub1, sub1)) + self.assertEqual("sub1-2", kullback_leibler.kl_divergence(sub1, sub2)) + self.assertEqual("sub2-1", kullback_leibler.kl_divergence(sub2, sub1)) + self.assertEqual("sub1-1", kullback_leibler.kl_divergence(sub11, sub11)) + self.assertEqual("sub1-1", kullback_leibler.kl_divergence(sub11, sub1)) + self.assertEqual("sub1-2", kullback_leibler.kl_divergence(sub11, sub2)) + self.assertEqual("sub1-1", kullback_leibler.kl_divergence(sub11, sub1)) + self.assertEqual("sub1-2", kullback_leibler.kl_divergence(sub11, sub2)) + self.assertEqual("sub2-1", kullback_leibler.kl_divergence(sub2, sub11)) + self.assertEqual("sub1-1", kullback_leibler.kl_divergence(sub1, sub11)) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py similarity index 96% rename from tensorflow/contrib/distributions/python/kernel_tests/normal_test.py rename to tensorflow/python/kernel_tests/distributions/normal_test.py index bda06df0f79..07c7d6d11d0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py +++ b/tensorflow/python/kernel_tests/distributions/normal_test.py @@ -18,13 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import importlib import math import numpy as np -from scipy import stats -from tensorflow.contrib.distributions.python.ops import kullback_leibler -from tensorflow.contrib.distributions.python.ops import normal as normal_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -33,7 +31,21 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variables +from tensorflow.python.ops.distributions import kullback_leibler +from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +stats = try_import("scipy.stats") class NormalTest(test.TestCase): @@ -90,10 +102,8 @@ class NormalTest(test.TestCase): sigma = constant_op.constant([math.sqrt(10.0)] * batch_size) x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) normal = normal_lib.Normal(loc=mu, scale=sigma) - expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x) log_pdf = normal.log_prob(x) - self.assertAllClose(expected_log_pdf, log_pdf.eval()) self.assertAllEqual(normal.batch_shape_tensor().eval(), log_pdf.get_shape()) self.assertAllEqual(normal.batch_shape_tensor().eval(), @@ -102,12 +112,17 @@ class NormalTest(test.TestCase): self.assertAllEqual(normal.batch_shape, log_pdf.eval().shape) pdf = normal.prob(x) - self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.get_shape()) self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.eval().shape) self.assertAllEqual(normal.batch_shape, pdf.get_shape()) self.assertAllEqual(normal.batch_shape, pdf.eval().shape) + if not stats: + return + expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + def testNormalLogPDFMultidimensional(self): with self.test_session(): batch_size = 6 @@ -116,12 +131,10 @@ class NormalTest(test.TestCase): batch_size) x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T normal = normal_lib.Normal(loc=mu, scale=sigma) - expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x) log_pdf = normal.log_prob(x) log_pdf_values = log_pdf.eval() self.assertEqual(log_pdf.get_shape(), (6, 2)) - self.assertAllClose(expected_log_pdf, log_pdf_values) self.assertAllEqual(normal.batch_shape_tensor().eval(), log_pdf.get_shape()) self.assertAllEqual(normal.batch_shape_tensor().eval(), @@ -132,12 +145,17 @@ class NormalTest(test.TestCase): pdf = normal.prob(x) pdf_values = pdf.eval() self.assertEqual(pdf.get_shape(), (6, 2)) - self.assertAllClose(np.exp(expected_log_pdf), pdf_values) self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.get_shape()) self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf_values.shape) self.assertAllEqual(normal.batch_shape, pdf.get_shape()) self.assertAllEqual(normal.batch_shape, pdf_values.shape) + if not stats: + return + expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf_values) + self.assertAllClose(np.exp(expected_log_pdf), pdf_values) + def testNormalCDF(self): with self.test_session(): batch_size = 50 @@ -146,14 +164,15 @@ class NormalTest(test.TestCase): x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) normal = normal_lib.Normal(loc=mu, scale=sigma) - expected_cdf = stats.norm(mu, sigma).cdf(x) - cdf = normal.cdf(x) - self.assertAllClose(expected_cdf, cdf.eval(), atol=0) self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.get_shape()) self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.eval().shape) self.assertAllEqual(normal.batch_shape, cdf.get_shape()) self.assertAllEqual(normal.batch_shape, cdf.eval().shape) + if not stats: + return + expected_cdf = stats.norm(mu, sigma).cdf(x) + self.assertAllClose(expected_cdf, cdf.eval(), atol=0) def testNormalSurvivalFunction(self): with self.test_session(): @@ -163,14 +182,16 @@ class NormalTest(test.TestCase): x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) normal = normal_lib.Normal(loc=mu, scale=sigma) - expected_sf = stats.norm(mu, sigma).sf(x) sf = normal.survival_function(x) - self.assertAllClose(expected_sf, sf.eval(), atol=0) self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.get_shape()) self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.eval().shape) self.assertAllEqual(normal.batch_shape, sf.get_shape()) self.assertAllEqual(normal.batch_shape, sf.eval().shape) + if not stats: + return + expected_sf = stats.norm(mu, sigma).sf(x) + self.assertAllClose(expected_sf, sf.eval(), atol=0) def testNormalLogCDF(self): with self.test_session(): @@ -180,15 +201,18 @@ class NormalTest(test.TestCase): x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64) normal = normal_lib.Normal(loc=mu, scale=sigma) - expected_cdf = stats.norm(mu, sigma).logcdf(x) cdf = normal.log_cdf(x) - self.assertAllClose(expected_cdf, cdf.eval(), atol=0, rtol=1e-5) self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.get_shape()) self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.eval().shape) self.assertAllEqual(normal.batch_shape, cdf.get_shape()) self.assertAllEqual(normal.batch_shape, cdf.eval().shape) + if not stats: + return + expected_cdf = stats.norm(mu, sigma).logcdf(x) + self.assertAllClose(expected_cdf, cdf.eval(), atol=0, rtol=1e-5) + def testFiniteGradientAtDifficultPoints(self): for dtype in [np.float32, np.float64]: g = ops.Graph() @@ -217,15 +241,18 @@ class NormalTest(test.TestCase): x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64) normal = normal_lib.Normal(loc=mu, scale=sigma) - expected_sf = stats.norm(mu, sigma).logsf(x) sf = normal.log_survival_function(x) - self.assertAllClose(expected_sf, sf.eval(), atol=0, rtol=1e-5) self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.get_shape()) self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.eval().shape) self.assertAllEqual(normal.batch_shape, sf.get_shape()) self.assertAllEqual(normal.batch_shape, sf.eval().shape) + if not stats: + return + expected_sf = stats.norm(mu, sigma).logsf(x) + self.assertAllClose(expected_sf, sf.eval(), atol=0, rtol=1e-5) + def testNormalEntropyWithScalarInputs(self): # Scipy.stats.norm cannot deal with the shapes in the other test. with self.test_session(): @@ -233,16 +260,18 @@ class NormalTest(test.TestCase): sigma_v = 4.56 normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) - # scipy.stats.norm cannot deal with these shapes. - expected_entropy = stats.norm(mu_v, sigma_v).entropy() entropy = normal.entropy() - self.assertAllClose(expected_entropy, entropy.eval()) self.assertAllEqual(normal.batch_shape_tensor().eval(), entropy.get_shape()) self.assertAllEqual(normal.batch_shape_tensor().eval(), entropy.eval().shape) self.assertAllEqual(normal.batch_shape, entropy.get_shape()) self.assertAllEqual(normal.batch_shape, entropy.eval().shape) + # scipy.stats.norm cannot deal with these shapes. + if not stats: + return + expected_entropy = stats.norm(mu_v, sigma_v).entropy() + self.assertAllClose(expected_entropy, entropy.eval()) def testNormalEntropy(self): with self.test_session(): @@ -288,15 +317,18 @@ class NormalTest(test.TestCase): p = np.hstack((p, np.exp(-33), 1. - np.exp(-33))) normal = normal_lib.Normal(loc=mu, scale=sigma) - expected_x = stats.norm(mu, sigma).ppf(p) x = normal.quantile(p) - self.assertAllClose(expected_x, x.eval(), atol=0.) self.assertAllEqual(normal.batch_shape_tensor().eval(), x.get_shape()) self.assertAllEqual(normal.batch_shape_tensor().eval(), x.eval().shape) self.assertAllEqual(normal.batch_shape, x.get_shape()) self.assertAllEqual(normal.batch_shape, x.eval().shape) + if not stats: + return + expected_x = stats.norm(mu, sigma).ppf(p) + self.assertAllClose(expected_x, x.eval(), atol=0.) + def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype): g = ops.Graph() with g.as_default(): @@ -450,7 +482,7 @@ class NormalTest(test.TestCase): n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a) n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b) - kl = kullback_leibler.kl(n_a, n_b) + kl = kullback_leibler.kl_divergence(n_a, n_b) kl_val = sess.run(kl) kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * ( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py similarity index 95% rename from tensorflow/contrib/distributions/python/kernel_tests/special_math_test.py rename to tensorflow/python/kernel_tests/distributions/special_math_test.py index 6ccbeb71222..dc462bae56b 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/special_math_test.py +++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py @@ -19,18 +19,30 @@ from __future__ import division from __future__ import print_function import collections +import importlib import numpy as np -from scipy import special -from scipy import stats -from tensorflow.contrib.distributions.python.ops import special_math from tensorflow.python.framework import ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import variables +from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + + +special = try_import("scipy.special") +stats = try_import("scipy.stats") sm = special_math @@ -66,6 +78,9 @@ class NdtriTest(test.TestCase): def testNdtri(self): """Verifies that ndtri computation is correct.""" with self.test_session(): + if not special: + return + p = np.linspace(0., 1.0, 50).astype(np.float64) # Quantile performs piecewise rational approximation so adding some # special input values to make sure we hit all the pieces. @@ -113,6 +128,9 @@ class NdtrTest(test.TestCase): self._test_grid_no_log(dtype, grid_spec, error_spec) def _test_grid_log(self, dtype, grid_spec, error_spec): + if not special: + return + with self.test_session(): grid = _make_grid(dtype, grid_spec) actual = sm.log_ndtr(grid).eval() @@ -137,6 +155,9 @@ class NdtrTest(test.TestCase): atol=error_spec.atol) def _test_grid_no_log(self, dtype, grid_spec, error_spec): + if not special: + return + with self.test_session(): grid = _make_grid(dtype, grid_spec) actual = sm.ndtr(grid).eval() @@ -267,6 +288,9 @@ class NdtrGradientTest(test.TestCase): self.assert_all_true(np.isfinite(grad_eval)) # Versus scipy. + if not (special and stats): + return + expected = stats.norm.pdf(raw_grid) if self._use_log: expected /= special.ndtr(raw_grid) @@ -323,6 +347,9 @@ class LogCDFLaplaceTest(test.TestCase): _check_strictly_increasing(actual) # Versus scipy. + if not stats: + return + scipy_dist = stats.laplace(loc=0., scale=1.) expected = scipy_dist.logcdf(grid.astype(scipy_dtype)) self.assertAllClose( diff --git a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py b/tensorflow/python/ops/distributions/kullback_leibler.py similarity index 82% rename from tensorflow/contrib/distributions/python/ops/kullback_leibler.py rename to tensorflow/python/ops/distributions/kullback_leibler.py index 335fe7a5e2a..9770d82bd83 100644 --- a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py +++ b/tensorflow/python/ops/distributions/kullback_leibler.py @@ -44,11 +44,13 @@ def _registered_kl(type_a, type_b): return kl_fn -def kl(dist_a, dist_b, allow_nan_stats=True, name=None): - """Get the KL-divergence KL(dist_a || dist_b). +def kl_divergence(distribution_a, distribution_b, + allow_nan_stats=True, name=None): + """Get the KL-divergence KL(distribution_a || distribution_b). - If there is no KL method registered specifically for `type(dist_a)` and - `type(dist_b)`, then the class hierarchies of these types are searched. + If there is no KL method registered specifically for `type(distribution_a)` + and `type(distribution_b)`, then the class hierarchies of these types are + searched. If one KL method is registered between any pairs of classes in these two parent hierarchies, it is used. @@ -58,11 +60,11 @@ def kl(dist_a, dist_b, allow_nan_stats=True, name=None): If more than one such shortest path exists, the first method identified in the search is used (favoring a shorter MRO distance to - `type(dist_a)`). + `type(distribution_a)`). Args: - dist_a: The first distribution. - dist_b: The second distribution. + distribution_a: The first distribution. + distribution_b: The second distribution. allow_nan_stats: Python `bool`, default `True`. When `True`, statistics (e.g., mean, mode, variance) use the value "`NaN`" to indicate the result is undefined. When `False`, an exception is raised @@ -70,20 +72,22 @@ def kl(dist_a, dist_b, allow_nan_stats=True, name=None): name: Python `str` name prefixed to Ops created by this class. Returns: - A Tensor with the batchwise KL-divergence between dist_a and dist_b. + A Tensor with the batchwise KL-divergence between `distribution_a` + and `distribution_b`. Raises: NotImplementedError: If no KL method is defined for distribution types - of dist_a and dist_b. + of `distribution_a` and `distribution_b`. """ - kl_fn = _registered_kl(type(dist_a), type(dist_b)) + kl_fn = _registered_kl(type(distribution_a), type(distribution_b)) if kl_fn is None: raise NotImplementedError( - "No KL(dist_a || dist_b) registered for dist_a type %s and dist_b " - "type %s" % (type(dist_a).__name__, type(dist_b).__name__)) + "No KL(distribution_a || distribution_b) registered for distribution_a " + "type %s and distribution_b type %s" + % (type(distribution_a).__name__, type(distribution_b).__name__)) with ops.name_scope("KullbackLeibler"): - kl_t = kl_fn(dist_a, dist_b, name=name) + kl_t = kl_fn(distribution_a, distribution_b, name=name) if allow_nan_stats: return kl_t @@ -96,7 +100,7 @@ def kl(dist_a, dist_b, allow_nan_stats=True, name=None): math_ops.reduce_any(math_ops.is_nan(kl_t))), ["KL calculation between %s and %s returned NaN values " "(and was called with allow_nan_stats=False). Values:" - % (dist_a.name, dist_b.name), kl_t])]): + % (distribution_a.name, distribution_b.name), kl_t])]): return array_ops.identity(kl_t, name="checked_kl") diff --git a/tensorflow/contrib/distributions/python/ops/normal.py b/tensorflow/python/ops/distributions/normal.py similarity index 98% rename from tensorflow/contrib/distributions/python/ops/normal.py rename to tensorflow/python/ops/distributions/normal.py index ed12439c44d..4c531b03788 100644 --- a/tensorflow/contrib/distributions/python/ops/normal.py +++ b/tensorflow/python/ops/distributions/normal.py @@ -20,8 +20,6 @@ from __future__ import print_function import math -from tensorflow.contrib.distributions.python.ops import kullback_leibler -from tensorflow.contrib.distributions.python.ops import special_math from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,6 +30,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import kullback_leibler +from tensorflow.python.ops.distributions import special_math __all__ = [ diff --git a/tensorflow/contrib/distributions/python/ops/special_math.py b/tensorflow/python/ops/distributions/special_math.py similarity index 100% rename from tensorflow/contrib/distributions/python/ops/special_math.py rename to tensorflow/python/ops/distributions/special_math.py From 8a123f7d1bca579a1401a10804eb1a45e213d6bc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 14:59:25 -0800 Subject: [PATCH 35/38] Add small test to check hinge_softness != 0 in softplus bijector. Change: 154773145 --- .../python/kernel_tests/bijectors/softplus_test.py | 6 ++++++ .../distributions/python/ops/bijectors/softplus_impl.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py index e8abde40875..214b196b547 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py @@ -41,6 +41,12 @@ class SoftplusBijectorTest(test.TestCase): """Inverse log det jacobian, before being reduced.""" return -np.log(1 - np.exp(-y)) + def testHingeSoftnessZeroRaises(self): + with self.test_session(): + bijector = Softplus(event_ndims=0, hinge_softness=0., validate_args=True) + with self.assertRaisesOpError("must be non-zero"): + bijector.forward([1., 1.]).eval() + def testBijectorForwardInverseEventDimsZero(self): with self.test_session(): bijector = Softplus(event_ndims=0) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py index baa41cfcbde..81957fcf789 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py @@ -94,7 +94,9 @@ class Softplus(bijector.Bijector): if validate_args: nonzero_check = check_ops.assert_none_equal( ops.convert_to_tensor( - 0, dtype=self.hinge_softness.dtype), self.hinge_softness) + 0, dtype=self.hinge_softness.dtype), + self.hinge_softness, + message="hinge_softness must be non-zero") self._hinge_softness = control_flow_ops.with_dependencies( [nonzero_check], self.hinge_softness) From 7c561e09c05100fe68f00d66a2d27d1b490ee74e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 15:11:37 -0800 Subject: [PATCH 36/38] Explain when callables passed to tf.cond & tf.while_loop are run. Rename the parameters to tf.cond. Change: 154774725 --- tensorflow/python/ops/control_flow_ops.py | 111 ++++++++++++------ .../python/ops/control_flow_ops_test.py | 63 ++++++++++ tensorflow/tools/api/golden/tensorflow.pbtxt | 2 +- 3 files changed, 137 insertions(+), 39 deletions(-) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index a22e4819d51..91be9ddbd78 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -71,6 +71,7 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops.gen_control_flow_ops import * # pylint: enable=wildcard-import from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import deprecation from tensorflow.python.util import nest from tensorflow.python.util import tf_should_use @@ -1679,14 +1680,20 @@ def _UnpackIfSingleton(res): return res -def cond(pred, fn1, fn2, strict=False, name=None): - """Return `fn1()` if the boolean predicate `pred` is true else `fn2()`. +# pylint: disable=g-doc-args +@deprecation.deprecated_args( + None, + "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.", + "fn1", "fn2") +def cond(pred, true_fn=None, false_fn=None, strict=False, name=None, + fn1=None, fn2=None): + """Return `true_fn()` if the predicate `pred` is true else `false_fn()`. - `fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have - the same non-zero number and type of outputs. + `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and + `false_fn` must have the same non-zero number and type of outputs. Note that the conditional execution applies only to the operations defined in - `fn1` and `fn2`. Consider the following simple program: + `true_fn` and `false_fn`. Consider the following simple program: ```python z = tf.multiply(a, b) @@ -1700,28 +1707,35 @@ def cond(pred, fn1, fn2, strict=False, name=None): Although this behavior is consistent with the dataflow model of TensorFlow, it has occasionally surprised some users who expected a lazier semantics. + Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the + call to `cond`, and not at all during `Session.run()`). `cond` + stitches together the graph fragments created during the `true_fn` and + `false_fn` calls with some additional graph nodes to ensure that the right + branch gets executed depending on the value of `pred`. + `tf.cond` supports nested structures as implemented in - `tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same - (possibly nested) value structure of lists, tuples, and/or named tuples. + `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the + same (possibly nested) value structure of lists, tuples, and/or named tuples. Singleton lists and tuples form the only exceptions to this: when returned by - `fn1` and/or `fn2`, they are implicitly unpacked to single values. This - behavior is disabled by passing `strict=True`. + `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. + This behavior is disabled by passing `strict=True`. Args: - pred: A scalar determining whether to return the result of `fn1` or `fn2`. - fn1: The callable to be performed if pred is true. - fn2: The callable to be performed if pred is false. + pred: A scalar determining whether to return the result of `true_fn` or + `false_fn`. + true_fn: The callable to be performed if pred is true. + false_fn: The callable to be performed if pred is false. strict: A boolean that enables/disables 'strict' mode; see above. name: Optional name prefix for the returned tensors. Returns: - Tensors returned by the call to either `fn1` or `fn2`. If the callables - return a singleton list, the element is extracted from the list. + Tensors returned by the call to either `true_fn` or `false_fn`. If the + callables return a singleton list, the element is extracted from the list. Raises: - TypeError: if `fn1` or `fn2` is not callable. - ValueError: if `fn1` and `fn2` do not return the same number of tensors, or - return tensors of different types. + TypeError: if `true_fn` or `false_fn` is not callable. + ValueError: if `true_fn` and `false_fn` do not return the same number of + tensors, or return tensors of different types. Example: @@ -1736,12 +1750,30 @@ def cond(pred, fn1, fn2, strict=False, name=None): ``` """ - with ops.name_scope(name, "cond", [pred]) as name: - if not callable(fn1): - raise TypeError("fn1 must be callable.") - if not callable(fn2): - raise TypeError("fn2 must be callable.") + # We needed to make true_fn/false_fn keyword arguments for + # backwards-compatibility. This check exists so that we can convert back to + # having them be positional arguments. + # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after + # `fn1` and `fn2` are deleted. + if fn1 is not None: + if true_fn is not None: + raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.") + true_fn = fn1 + elif true_fn is None: + raise TypeError("cond(): true_fn argument required") + if fn2 is not None: + if false_fn is not None: + raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.") + false_fn = fn2 + elif false_fn is None: + raise TypeError("cond(): false_fn argument required") + if not callable(true_fn): + raise TypeError("true_fn must be callable.") + if not callable(false_fn): + raise TypeError("false_fn must be callable.") + + with ops.name_scope(name, "cond", [pred]) as name: # Add the Switch to the graph. if isinstance(pred, bool): raise TypeError("pred must not be a Python bool") @@ -1756,18 +1788,18 @@ def cond(pred, fn1, fn2, strict=False, name=None): # Build the graph for the true branch in a new context. context_t = CondContext(pred, pivot_1, branch=1) context_t.Enter() - orig_res_t, res_t = context_t.BuildCondBranch(fn1) + orig_res_t, res_t = context_t.BuildCondBranch(true_fn) if orig_res_t is None: - raise ValueError("fn1 must have a return value.") + raise ValueError("true_fn must have a return value.") context_t.ExitResult(res_t) context_t.Exit() # Build the graph for the false branch in a new context. context_f = CondContext(pred, pivot_2, branch=0) context_f.Enter() - orig_res_f, res_f = context_f.BuildCondBranch(fn2) + orig_res_f, res_f = context_f.BuildCondBranch(false_fn) if orig_res_f is None: - raise ValueError("fn2 must have a return value.") + raise ValueError("false_fn must have a return value.") context_f.ExitResult(res_f) context_f.Exit() @@ -1780,14 +1812,14 @@ def cond(pred, fn1, fn2, strict=False, name=None): nest.assert_same_structure(orig_res_t, orig_res_f) except TypeError as e: raise TypeError( - "Incompatible return types of fn1 and fn2: {}".format(e)) + "Incompatible return types of true_fn and false_fn: {}".format(e)) except ValueError as e: raise ValueError( - "Incompatible return values of fn1 and fn2: {}".format(e)) + "Incompatible return values of true_fn and false_fn: {}".format(e)) # Add the final merge to the graph. if not res_t: - raise ValueError("fn1 and fn2 must return at least one result.") + raise ValueError("true_fn and false_fn must return at least one result.") res_t_flat = nest.flatten(res_t) res_f_flat = nest.flatten(res_f) @@ -1801,8 +1833,9 @@ def cond(pred, fn1, fn2, strict=False, name=None): val_x = x if isinstance(x, ops.Tensor) else x.values val_y = y if isinstance(y, ops.Tensor) else y.values if val_x.dtype.base_dtype != val_y.dtype.base_dtype: - raise ValueError("Outputs of fn1 and fn2 must have the same type: " - "%s, %s" % (val_x.dtype.name, val_y.dtype.name)) + raise ValueError( + "Outputs of true_fn and false_fn must have the same type: %s, %s" % + (val_x.dtype.name, val_y.dtype.name)) merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)] merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges) @@ -1817,6 +1850,7 @@ def cond(pred, fn1, fn2, strict=False, name=None): if not strict: merges = _UnpackIfSingleton(merges) return merges +# pylint: enable=g-doc-args def _resource_safe_shape(t): @@ -2548,12 +2582,16 @@ def while_loop(cond, body, loop_vars, shape_invariants=None, `cond` and `body`. `cond` and `body` both take as many arguments as there are `loop_vars`. - While `cond` evaluates to true, `body` is executed. - In addition to regular Tensors or IndexedSlices, the body may accept and return TensorArray objects. The flows of the TensorArray objects will be appropriately forwarded between loops and during gradient calculations. + Note that `while_loop` calls `cond` and `body` *exactly once* (inside the + call to `while_loop`, and not at all during `Session.run()`). `while_loop` + stitches together the graph fragments created during the `cond` and `body` + calls with some additional graph nodes to make something the repeats + `body` until `cond` returns false. + For correctness, `tf.while_loop()` strictly enforces shape invariants for the loop variables. A shape invariant is a (possibly partial) shape that is unchanged across the iterations of the loop. An error will be raised @@ -2882,10 +2920,10 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"): operation returns the tensors generated by `default`. `tf.case` supports nested structures as implemented in - `tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same + `tensorflow.python.util.nest`. All of the callables must return the same (possibly nested) value structure of lists, tuples, and/or named tuples. Singleton lists and tuples form the only exceptions to this: when returned by - `fn1` and/or `fn2`, they are implicitly unpacked to single values. This + a callable, they are implicitly unpacked to single values. This behavior is disabled by passing `strict=True`. Example 1: @@ -2913,9 +2951,6 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"): Expressions: ``` - x = tf.constant(0) - y = tf.constant(1) - z = tf.constant(2) def f1(): return tf.constant(17) def f2(): return tf.constant(23) def f3(): return tf.constant(-1) diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index 7704254b013..4e95783e5a8 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -324,6 +324,69 @@ class SwitchTestCase(TensorFlowTestCase): self.assertEquals(grad_x_false.eval(), 0.) +class CondTest(TensorFlowTestCase): + + def testCondTrue(self): + with self.test_session(): + x = constant_op.constant(2) + y = constant_op.constant(5) + z = control_flow_ops.cond( + math_ops.less(x, y), lambda: math_ops.multiply(x, 17), + lambda: math_ops.add(y, 23)) + self.assertEquals(z.eval(), 34) + + def testCondFalse(self): + with self.test_session(): + x = constant_op.constant(2) + y = constant_op.constant(1) + z = control_flow_ops.cond( + math_ops.less(x, y), lambda: math_ops.multiply(x, 17), + lambda: math_ops.add(y, 23)) + self.assertEquals(z.eval(), 24) + + def testCondTrueLegacy(self): + with self.test_session(): + x = constant_op.constant(2) + y = constant_op.constant(5) + z = control_flow_ops.cond( + math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17), + fn2=lambda: math_ops.add(y, 23)) + self.assertEquals(z.eval(), 34) + + def testCondFalseLegacy(self): + with self.test_session(): + x = constant_op.constant(2) + y = constant_op.constant(1) + z = control_flow_ops.cond( + math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17), + fn2=lambda: math_ops.add(y, 23)) + self.assertEquals(z.eval(), 24) + + def testCondMissingArg1(self): + with self.test_session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + control_flow_ops.cond(True, false_fn=lambda: x) + + def testCondMissingArg2(self): + with self.test_session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + control_flow_ops.cond(True, lambda: x) + + def testCondDuplicateArg1(self): + with self.test_session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x) + + def testCondDuplicateArg2(self): + with self.test_session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x) + + class ContextTest(TensorFlowTestCase): def testCondContext(self): diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index c3c4145763b..fb636d9525f 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -718,7 +718,7 @@ tf_module { } member_method { name: "cond" - argspec: "args=[\'pred\', \'fn1\', \'fn2\', \'strict\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'pred\', \'true_fn\', \'false_fn\', \'strict\', \'name\', \'fn1\', \'fn2\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'None\', \'None\'], " } member_method { name: "confusion_matrix" From 0135602ffd33f21d702ee7a3a55e86cd85af45ca Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 16:09:21 -0800 Subject: [PATCH 37/38] Add doc for saved_model_cli Change: 154781345 --- .../docs_src/programmers_guide/leftnav_files | 1 + .../programmers_guide/saved_model_cli.md | 251 ++++++++++++++++++ tensorflow/python/tools/saved_model_cli.py | 3 +- 3 files changed, 254 insertions(+), 1 deletion(-) create mode 100644 tensorflow/docs_src/programmers_guide/saved_model_cli.md diff --git a/tensorflow/docs_src/programmers_guide/leftnav_files b/tensorflow/docs_src/programmers_guide/leftnav_files index d3979172190..322e11cbd69 100644 --- a/tensorflow/docs_src/programmers_guide/leftnav_files +++ b/tensorflow/docs_src/programmers_guide/leftnav_files @@ -8,6 +8,7 @@ supervisor.md debugger.md tfdbg-tflearn.md meta_graph.md +saved_model_cli.md version_semantics.md data_versions.md faq.md diff --git a/tensorflow/docs_src/programmers_guide/saved_model_cli.md b/tensorflow/docs_src/programmers_guide/saved_model_cli.md new file mode 100644 index 00000000000..eb9e60e42ed --- /dev/null +++ b/tensorflow/docs_src/programmers_guide/saved_model_cli.md @@ -0,0 +1,251 @@ +# SavedModel CLI (Command-Line Interface) + +[`SavedModel`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md) +is a universal serialization format for Tensorflow. It provides a +language-neutral format to save machine-learned models and enables higher-level +systems and tools to produce, consume and transform TensorFlow models. + +We provide SavedModel CLI(command-line interface) as a tool to inspect and +execute a [`MetaGraph`](https://www.tensorflow.org/programmers_guide/meta_graph) +in a SavedModel. You can inspect for example, what +[`SignatureDefs`](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/signature_defs.md), +including their input and output tensors, are in the model without writing any +code. This can be useful in situations such as when you want to quickly check +your input dtype and shape match with the model. Moreover, if you want to test +out the model, it also allows you to do a sanity check by passing in sample +inputs in the format of for example, python expressions, and fetch the outputs +simply through command line. + +## Get SavedModel CLI + +If TensorFlow is installed on your system through pip, the `saved_model_cli` +binary can be invoked directly from command line. + +To build the binary from source, run the following command: + +``` +$bazel build tensorflow/python/tools:saved_model_cli +``` + +## Commands + +SavedModel CLI allows users to both show and run computations on a +[`MetaGraphDef`](https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto) +in a SavedModel. These are done through `show` and `run` commands. We will +explain the usages of both commands with detailed examples. SavedModel CLI will +also display this information with `-h` option. + +### `show` command + +A SavedModel contains one or more MetaGraphs, identified by their tag-sets. Each +MetaGraph contains both a TensorFlow GraphDef as well as associated metadata +necessary for running computation in a graph. In order to serve a model, you +might wonder what kind of SignatureDefs are in each model, and what are their +inputs and outputs etc. The `show` command let you examine the content of the +SavedModel in a hierarchical order. + +``` +usage: saved_model_cli show [-h] --dir DIR [--all] +[--tag_set TAG_SET] [--signature_def SIGNATURE_DEF_KEY] +``` + +#### Examples + +To show all available MetaGraphDef tag-sets in the SavedModel: + +``` +$saved_model_cli show --dir /tmp/saved_model_dir +The given SavedModel contains the following tag-sets: +serve +serve, gpu +``` + +To show all available SignatureDef keys in a MetaGraphDef: + +``` +$saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve +The given SavedModel MetaGraphDef contains SignatureDefs with the following keys: +SignatureDef key: "classify_x2_to_y3" +SignatureDef key: "classify_x_to_y" +SignatureDef key: "regress_x2_to_y3" +SignatureDef key: "regress_x_to_y" +SignatureDef key: "regress_x_to_y2" +SignatureDef key: "serving_default" +``` + +For a MetaGraphDef with multiple tags in the tag-set, all tags must be passed +in, separated by ',': + +``` +$saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve,gpu +``` + +To show all inputs and outputs TensorInfo for a specific SignatureDef, pass in +the SignatureDef key to `signature_def` option. This is very useful when you +want to know the tensor key value, dtype and shape of the input tensors for +executing the computation graph later. + +``` +$saved_model_cli show --dir \ +/tmp/saved_model_dir --tag_set serve --signature_def serving_default +The given SavedModel SignatureDef contains the following input(s): +inputs['x'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x:0 +The given SavedModel SignatureDef contains the following output(s): +outputs['y'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y:0 +Method name is: tensorflow/serving/predict +``` + +To show all available information in the SavedModel, use `--all` option: + +``` +$saved_model_cli show --dir /tmp/saved_model_dir --all +MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs: + +signature_def['classify_x2_to_y3']: +The given SavedModel SignatureDef contains the following input(s): +inputs['inputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x2:0 +The given SavedModel SignatureDef contains the following output(s): +outputs['scores'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y3:0 +Method name is: tensorflow/serving/classify + +... + +signature_def['serving_default']: +The given SavedModel SignatureDef contains the following input(s): +inputs['x'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x:0 +The given SavedModel SignatureDef contains the following output(s): +outputs['y'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y:0 +Method name is: tensorflow/serving/predict +``` + +### `run` command + +SavedModel CLI also allows you to run a graph computation by passing in inputs, +displaying, and saving the outputs. + +``` +usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def + SIGNATURE_DEF_KEY [--inputs INPUTS] + [--input_exprs INPUT_EXPRS] [--outdir OUTDIR] + [--overwrite] [--tf_debug] +``` + +Tensor keys are used to specify which input we are passing in the values for. +There are two ways to pass inputs to the model. With '--inputs' option, you can +pass in numpy ndarray by files. With '--input_exprs' option, you can use python +expressions as inputs. + +#### Input By File + +To pass in inputs by files, use `--inputs` option in the format of +`<input_key>=<filename>`, or `<input_key>=<filename>[<variable_name>]`. Each +input is separated by semicolon. File specified by `filename` will be loaded +using `numpy.load`. Inputs can be loaded from only `.npy`, `.npz` or pickle +files. The `variable_name` key is optional depending on the input file type as +descripted in more details below. + +When loading from a `.npy` file, which always contains a numpy ndarray, the +content will be directly assigned to the specified input tensor. If a +`variable_name` is specified, it will be ignored and a warning will be issued. + +When loading from a `.npz` zip file, user can specify which variable within the +zip file to load for the input tensor key with `variable_name`. If nothing is +specified, SavedModel CLI will check that only one file is included in the zip +file and load it for the specified input tensor key. + +When loading from a pickle file, if no `variable_name` is specified in the +square brackets, whatever that is inside the pickle file will be passed to the +specified input tensor key. Else SavedModel CLI will assume a dictionary is +stored in the pickle file and the value corresponding to the variable_name will +be used. + +#### Input By Python Expression + +To pass in inputs by python expressions, use `--input_exprs` option. `numpy` +module is available as `np`. For example, `input_key=np.ones((32, 32, 3))` or +`input_key=[[1], [2], [3]]`. This can be useful for when you don't have data +files lying around, but still want to sanity check the model with some simple +inputs that match the dtype and shape of the model signature. + +#### Save Output + +By default, SavedModel CLI will print outputs to console. If a directory is +passed to `--outdir` option, the outputs will be saved as npy files named after +output tensor keys under the given directory. Use `--overwite` to overwrite +existing output files. + +#### TensorFlow Debugger (tfdbg) Integration + +If `--tf_debug` option is set, SavedModel CLI will use TensorFlow Debugger +(tfdbg) to watch the intermediate Tensors and runtime GraphDefs while running +the SavedModel. + +#### Examples + +If we have a simple model that adds `x1` and `x2` to get output `y`, where all +tensors are of shape `(-1, 1)`, and we have two `npz` files. File +`/tmp/my_data1.npy` contains a numpy ndarray `[[1], [2], [3]]`, file +`/tmp/my_data2.npy` contains another numpy ndarray `[[0.5], [0.5], [0.5]]`. Now +let's run these two `npy` files through the model to get `y`: + +``` +$saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \ +--signature_def x1_x2_to_y --inputs x1=/tmp/my_data1.npy;x2=/tmp/my_data2.npy \ +--outdir /tmp/out +Result for output key y: +[[ 1.5] + [ 2.5] + [ 3.5]] +``` + +Similarly, we can run input tensors from `npz` file and pickle file, as well as +overwrite the previous output file: + +``` +$saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \ +--signature_def x1_x2_to_y \ +--inputs x1=/tmp/my_data1.npz[x];x2=/tmp/my_data2.pkl --outdir /tmp/out \ +--overwrite +Result for output key y: +[[ 1.5] + [ 2.5] + [ 3.5]] +``` + +You can also use python expression instead of input file. Here we replace input +`x2` with a python expression: + +``` +$saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \ +--signature_def x1_x2_to_y --inputs x1=/tmp/my_data1.npz[x] \ +--input_exprs 'x2=np.ones((3,1))' +Result for output key y: +[[ 2] + [ 3] + [ 4]] +``` + +To run model with TensorFlow Debugger on: + +``` +$saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \ +--signature_def serving_default --inputs x=/tmp/data.npz[x] --tf_debug +``` diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 1c55c5bf531..9dfafb77e7d 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -602,7 +602,8 @@ def create_parser(): '--signature_def serving_default ' '--inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy' '--outdir=/out\n\n' - 'For more information about input file format, please see:\n') + 'For more information about input file format, please see:\n' + 'https://www.tensorflow.org/programmers_guide/saved_model_cli\n') parser_run = subparsers.add_parser( 'run', description=run_msg, formatter_class=argparse.RawTextHelpFormatter) parser_run.add_argument( From aebaf317ce50ef42ebaab91191224c872389a6f6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 1 May 2017 16:14:06 -0800 Subject: [PATCH 38/38] This CL removes the Graph.edge_set_ field. This field stores a set of the Edge* that are in a Graph. However, Graph already stores this information, in Graph.edges_. There's really no good reason to keep both of these collections. To convert everything to use Graph.edges_ instead of Graph.edge_set_, I defined a class which handled excluding nullptr from iteration of the edges_ vector. This caused changes to non-contractual behavior of the runtime (enumeration order), so the unit tests are updated to reflect this. On a real-world graph used by our team, which contains 13190 nodes and 20796 edges, this change reduced heap allocation from 39.1 MB to 38.0 MB, for a drop of about 3%. Change: 154781831 --- tensorflow/core/common_runtime/function.cc | 2 +- .../core/common_runtime/function_test.cc | 14 ++-- tensorflow/core/graph/graph.cc | 5 +- tensorflow/core/graph/graph.h | 69 +++++++++++++++++-- 4 files changed, 75 insertions(+), 15 deletions(-) diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 0f2e24690f3..13e20568fff 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -456,7 +456,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate( void DumpGraph(StringPiece label, const Graph* g) { // TODO(zhifengc): Change Graph to record #nodes. VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges " - << g->edges().size(); + << g->num_edges(); if (VLOG_IS_ON(2)) { for (const auto& line : str_util::Split(DebugString(g), '\n')) { VLOG(2) << "|| " << line; diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index bbf35590eb6..8f70ab8783c 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -424,7 +424,7 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) { n8 = NoOp() @ n4 n9 = Identity[T=float](n3) @ n8 n10 = Identity[T=float](n2) @ n8 - n11 = NoOp() @ n10, n9 + n11 = NoOp() @ n9, n10 n5 = Mul[T=float](n2, n2) @ n11 n6 = Add[T=float](n4, n5) } @@ -500,8 +500,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { OptimizeGraph(lib_, &g); const char* e2 = R"P( (n2:float, n3:float) -> (n9:float) { - n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [0] values: >]() n10 = Const[dtype=float, value=Tensor<type: float shape: [] values: 2>]() + n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [0] values: >]() n6 = Shape[T=float, out_type=int32](n2) n5 = Mul[T=float](n3, n10) n7 = BroadcastGradientArgs[T=int32](n6, n11) @@ -614,10 +614,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { n17 = Sum[T=float, Tidx=int32, keep_dims=false](n14, n16) n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, Tidx=int32, keep_dims=false]](n14, n16, n26) n21 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Add[T=float]](n24, n25, n19) - n28 = Identity[T=float](n21:1) n27 = Identity[T=float](n21) - n6 = Identity[T=float](n28) + n28 = Identity[T=float](n21:1) n8 = Identity[T=float](n27) + n6 = Identity[T=float](n28) } )P"; EXPECT_EQ(e1, DebugString(g.get())); @@ -626,8 +626,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { const char* e2 = R"P( (n4:float, n3:float) -> (n25:float, n23:float) { n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]() - n8 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]() + n8 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() n19 = Shape[T=float, out_type=int32](n3) n9 = Add[T=float](n4, n3) n20 = Shape[T=float, out_type=int32](n4) @@ -641,10 +641,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { n16 = Reshape[T=float, Tshape=int32](n2, n15) n17 = Div[T=int32](n14, n15) n18 = Tile[T=float, Tmultiples=int32](n16, n17) - n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21) n22 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21:1) - n25 = Reshape[T=float, Tshape=int32](n24, n20) + n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21) n23 = Reshape[T=float, Tshape=int32](n22, n19) + n25 = Reshape[T=float, Tshape=int32](n24, n20) } )P"; EXPECT_EQ(e2, DebugString(g.get())); diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index e1657cb8622..a68a8f25093 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -344,7 +344,7 @@ const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) { CHECK(source->out_edges_.insert(e).second); CHECK(dest->in_edges_.insert(e).second); edges_.push_back(e); - edge_set_.insert(e); + ++num_edges_; return e; } @@ -354,8 +354,8 @@ void Graph::RemoveEdge(const Edge* e) { CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1}); CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1}); CHECK_EQ(e, edges_[e->id_]); + CHECK_GT(num_edges_, 0); - CHECK_EQ(edge_set_.erase(e), size_t{1}); edges_[e->id_] = nullptr; Edge* del = const_cast<Edge*>(e); @@ -365,6 +365,7 @@ void Graph::RemoveEdge(const Edge* e) { del->src_output_ = kControlSlot - 1; del->dst_input_ = kControlSlot - 1; free_edges_.push_back(del); + --num_edges_; } Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 11a49ec3b3d..bbb3af196d6 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -268,6 +268,66 @@ class Edge { int dst_input_; }; +// Allows for iteration of the edges of a Graph, by iterating the underlying +// Graph.edges_ vector while skipping over null entries. +class GraphEdgesIterable { + private: + const std::vector<Edge*>& edges_; + + public: + explicit GraphEdgesIterable(const std::vector<Edge*>& edges) + : edges_(edges) {} + + typedef Edge* value_type; + + class const_iterator { + private: + // The underlying iterator. + std::vector<value_type>::const_iterator iter_; + + // The end of the underlying iterator. + std::vector<value_type>::const_iterator end_; + + // Advances iter_ until it reaches a non-null item, or reaches the end. + void apply_filter() { + while (iter_ != end_ && *iter_ == nullptr) { + ++iter_; + } + } + + public: + const_iterator(std::vector<value_type>::const_iterator iter, + std::vector<value_type>::const_iterator end) + : iter_(iter), end_(end) { + apply_filter(); + } + + bool operator==(const const_iterator& other) const { + return iter_ == other.iter_; + } + + bool operator!=(const const_iterator& other) const { + return iter_ != other.iter_; + } + + // This is the prefix increment operator (++x), which is the operator + // used by C++ range iteration (for (x : y) ...). We intentionally do not + // provide a postfix increment operator. + const_iterator& operator++() { + ++iter_; + apply_filter(); + return *this; + } + + value_type operator*() { return *iter_; } + }; + + const_iterator begin() { + return const_iterator(edges_.begin(), edges_.end()); + } + const_iterator end() { return const_iterator(edges_.end(), edges_.end()); } +}; + // Thread compatible but not thread safe. class Graph { public: @@ -345,7 +405,7 @@ class Graph { // smaller than num_edge_ids(). If one needs to create an array of // edges indexed by edge ids, num_edge_ids() should be used as the // array's size. - int num_edges() const { return edges().size(); } + int num_edges() const { return num_edges_; } // Serialize the nodes starting at `from_node_id` to a GraphDef. void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const; @@ -381,7 +441,7 @@ class Graph { // Access to the set of all edges. Example usage: // for (const Edge* e : graph.edges()) { ... } - const EdgeSet& edges() const { return edge_set_; } + GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); } // The pre-defined nodes. enum { kSourceId = 0, kSinkId = 1 }; @@ -421,9 +481,8 @@ class Graph { // the edge with that id was removed from the graph. std::vector<Edge*> edges_; - // For ease of iteration, we currently just keep a set of all live - // edges. May want to optimize by removing this copy. - EdgeSet edge_set_; + // The number of entries in edges_ that are not nullptr. + int num_edges_ = 0; // Allocated but free nodes and edges. std::vector<Node*> free_nodes_;