diff --git a/configure.py b/configure.py
index fcf359d061d..ef5051d275e 100644
--- a/configure.py
+++ b/configure.py
@@ -688,7 +688,8 @@ def set_tf_cunn_version(environ_cp):
         cudnn_path_from_ldconfig)
     if cudnn_path_from_ldconfig:
       cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1)
-      if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
+      if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig,
+                                   tf_cudnn_version)):
         cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
         break
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 5b6a18b6a69..5538052d020 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -290,6 +290,7 @@ filegroup(
         "//tensorflow/contrib/decision_trees/proto:all_files",
         "//tensorflow/contrib/distributions:all_files",
         "//tensorflow/contrib/eager/python:all_files",
+        "//tensorflow/contrib/estimator:all_files",
         "//tensorflow/contrib/factorization:all_files",
         "//tensorflow/contrib/factorization/kernels:all_files",
         "//tensorflow/contrib/ffmpeg:all_files",
@@ -407,6 +408,7 @@ filegroup(
         "//tensorflow/python/eager:all_files",
         "//tensorflow/python/estimator:all_files",
        "//tensorflow/python/feature_column:all_files",
+        "//tensorflow/python/keras:all_files",
        "//tensorflow/python/kernel_tests:all_files",
        "//tensorflow/python/kernel_tests/distributions:all_files",
        "//tensorflow/python/ops/distributions:all_files",
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index c454c94249b..334f867e478 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -374,6 +374,65 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
   status->status = Reset(opt->options, container_names);
 }
 
+// This traverses the specified nodes in topological order to verify there are
+// no cycles. Starting with inputless nodes, it visits nodes whose inputs have
+// all been visited, and counts the total number of visited nodes. If there is a
+// cycle, nodes in the cycle will never be visited, and the visited count will
+// be less than the total node count.
+Status ValidateNoCycles(const Graph& g) {
+  // TODO(nolivia): check this on a subset of the graph instead of all of it.
+  int total_num_nodes = g.num_node_ids();
+  // A node is ready when all of its inputs have been visited.
+  std::vector<const Node*> ready;
+  std::vector<int> pending_count(total_num_nodes, 0);
+
+  for (int i = 0; i < total_num_nodes; ++i) {
+    const Node* n = g.FindNodeId(i);
+    if (n == nullptr) continue;
+    pending_count[i] = n->in_edges().size();
+    if (n->IsMerge()) {
+      // While-loop cycles are legal cycles so we manually adjust the
+      // pending_count to make sure that the loop is visited.
+      for (const Edge* e : n->in_edges()) {
+        if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
+          pending_count[i]--;
+        }
+      }
+    }
+    if (pending_count[i] == 0) {
+      ready.push_back(n);
+    }
+  }
+
+  int processed = 0;
+  while (!ready.empty()) {
+    const Node* node = ready.back();
+    ready.pop_back();
+    ++processed;
+
+    for (const Edge* out : node->out_edges()) {
+      const int output_id = out->dst()->id();
+      pending_count[output_id]--;
+      if (pending_count[output_id] == 0) {
+        ready.push_back(out->dst());
+      }
+    }
+  }
+
+  if (processed < total_num_nodes) {
+    std::vector<string> nodes_in_cycle;
+    for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3;
+         ++i) {
+      if (pending_count[i] != 0) {
+        nodes_in_cycle.push_back(g.FindNodeId(i)->name());
+      }
+    }
+    return errors::InvalidArgument(
+        "Graph is invalid, contains a cycle with ", total_num_nodes - processed,
+        " nodes, including: ", str_util::Join(nodes_in_cycle, ", "));
+  }
+  return Status::OK();
+}
 }  // namespace
 }  // namespace tensorflow
 
@@ -2251,6 +2310,12 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
   const Graph& graph = session->graph->graph;
 
   const auto num_nodes = graph.num_node_ids();
   if (session->last_num_graph_nodes < num_nodes) {
+    status->status = tensorflow::ValidateNoCycles(session->graph->graph);
+    if (!status->status.ok()) {
+      session->graph->mu.unlock();
+      return false;
+    }
+
     GraphDef graph_def;
     *graph_def.mutable_versions() = graph.versions();
     // Fill graph_def with nodes with ids in the range
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index d19583a3abe..72e0fe8a156 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -38,6 +38,7 @@ TFE_TensorHandle* TestMatrixTensorHandle() {
   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TF_DeleteTensor(t);
+  TF_DeleteStatus(status);
   return th;
 }
 
@@ -385,7 +386,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
   memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
 
   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
-      value_handle(TFE_NewTensorHandle(t.get(), status), TFE_DeleteTensorHandle);
+      value_handle(TFE_NewTensorHandle(t.get(), status),
+                   TFE_DeleteTensorHandle);
   if (TF_GetCode(status) != TF_OK) return nullptr;
 
   TFE_OpAddInput(op, value_handle.get(), status);
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index adca6c76252..b8d36b89472 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -20,7 +20,6 @@ limitations under the License.
 namespace tensorflow {
 
 void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
-  // TODO(skyewm): make sure cycles are prevented
   mutex_lock l(graph->mu);
   graph->graph.AddControlEdge(&input->node, &op->node);
 }
diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc
index 1868207148d..82469261e5b 100644
--- a/tensorflow/cc/framework/gradients.cc
+++ b/tensorflow/cc/framework/gradients.cc
@@ -77,7 +77,7 @@ class SymbolicGradientBuilder {
   Status CallGradFunction(const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs);
-
+
   // Returns a list mapping whether each node in the graph is reachable
   // from outputs_. Keyed by node id.
   std::vector<bool> GetReachableNodes();
@@ -156,7 +156,7 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
       reachable_nodes[out.node()->id()] = true;
     }
   }
-
+
   while (!queue.empty()) {
     Node* n = queue.front();
     queue.pop_front();
diff --git a/tensorflow/cc/framework/testutil.cc b/tensorflow/cc/framework/testutil.cc
index 25ee08f6762..57d573e3c5a 100644
--- a/tensorflow/cc/framework/testutil.cc
+++ b/tensorflow/cc/framework/testutil.cc
@@ -37,7 +37,7 @@ void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
 }
 
 void GetTensors(const Scope& scope, const std::vector<Output>& assign_vars,
-                OutputList tensors, std::vector<Tensor>* out) {
+                const OutputList& tensors, std::vector<Tensor>* out) {
   ClientSession session(scope);
   TF_CHECK_OK(session.Run(assign_vars, nullptr));
   TF_CHECK_OK(session.Run(tensors, out));
diff --git a/tensorflow/cc/framework/testutil.h b/tensorflow/cc/framework/testutil.h
index ca57c0f0a40..a3e19870ec8 100644
--- a/tensorflow/cc/framework/testutil.h
+++ b/tensorflow/cc/framework/testutil.h
@@ -30,7 +30,7 @@ void GetTensors(const Scope& scope, OutputList tensors,
 // assign_vars are extra outputs that should be run
 // e.g. to assign values to variables.
 void GetTensors(const Scope& scope, const std::vector<Output>& assign_vars,
-                OutputList tensors, std::vector<Tensor>* out);
+                const OutputList& tensors, std::vector<Tensor>* out);
 
 /// Computes the output 'tensor', returning the resulting tensor in 'out'.
 void GetTensor(const Scope& scope, Output tensor, Tensor* out);
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 71995b2307e..6190bd624db 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -94,14 +94,17 @@ Status Literal::CopyRange(const Literal& src_literal,
   TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size());
   TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size());
+
   if (ShapeUtil::Rank(src_shape) == 0 || ShapeUtil::Rank(dest_shape) == 0) {
     // If any of the two shapes are scalars, we can just call the StridedCopy()
     // directly, and we know we will be copying only one value.
     TF_RET_CHECK(copy_size.empty());
     StridedCopy(dest_data, LinearIndex(dest_base), 0, src_data,
                 src_literal.LinearIndex(src_base), 0, 1);
-  } else if (!ShapeUtil::HasZeroElements(dest_shape)) {
-    TF_RET_CHECK(!ShapeUtil::HasZeroElements(src_shape));
+  } else if (!ShapeUtil::HasZeroElements(dest_shape) &&
+             !ShapeUtil::HasZeroElements(src_shape)) {
+    // Perform copy if neither src literal nor dest literal has dimensions with
+    // zero element, otherwise it's a no-op.
     TF_RET_CHECK(src_base.size() == dest_base.size());
     TF_RET_CHECK(src_base.size() == copy_size.size());
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 447c494bfca..64513459186 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -237,6 +237,9 @@ class Literal {
   // The src_literal and this literal must have the same primitive type,
   // src_base+copy_size must fit the source literal dimensions, as well as
   // dest_base+copy_size must fit the destination literal dimensions.
+  // Note: if either src_literal or this literal contains dimensions with zero
+  // element, then copy_size must be 0 in these dimensions while the
+  // corresponding base indices being 0.
   Status Copy(const Literal& src_literal,
               tensorflow::gtl::ArraySlice<int64> src_base,
               tensorflow::gtl::ArraySlice<int64> dest_base,
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index a33c0fe09dd..61ceac4f9a6 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -698,7 +698,7 @@ TEST_F(LiteralUtilTest, Copy) {
   for (const auto& layout : layouts) {
     Shape shape = ShapeUtil::MakeShapeWithLayout(
         primitive_util::NativeToPrimitiveType<uint32>(), dimensions, layout);
-    auto blank = Literal::CreateFromShape(shape);
+    auto source = Literal::CreateFromShape(shape);
 
     const int64 zero_base[] = {0, 0, 0, 0};
     const int64 step[] = {1, 1, 1, 1};
@@ -707,15 +707,15 @@ TEST_F(LiteralUtilTest, Copy) {
       source->Set(indexes, ++seqnr);
       return true;
     };
-
     ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step,
                             init_proc);
+
     auto blank = Literal::CreateFromShape(shape);
     const int64 src_base[] = {3, 1, 5, 7};
     const int64 dest_base[] = {6, 4, 12, 2};
     const int64 copy_size[] = {7, 8, 11, 9};
-
     TF_EXPECT_OK(blank->Copy(*source, src_base, dest_base, copy_size));
+
     std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
     std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
     bool matched = true;
@@ -730,6 +730,7 @@ TEST_F(LiteralUtilTest, Copy) {
       matched = (bval != 0 && bval == source->Get<uint32>(source_indexes));
       return matched;
     };
+
     ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step,
                             check_proc);
 
     EXPECT_TRUE(matched);
@@ -749,6 +750,30 @@ TEST_F(LiteralUtilTest, CopyScalars) {
   EXPECT_EQ(vect->Get<uint32>({4}), 17);
 }
 
+TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
+  const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0});
+  const auto const_nine = Literal::CreateR1<float>({9});
+  const auto const_empty = Literal::CreateFromShape(empty_r1_shape);
+
+  {
+    // Source contains dimension with zero elements.
+    const auto empty = Literal::CreateFromShape(empty_r1_shape);
+    auto nine = Literal::CreateR1<float>({9});
+
+    TF_EXPECT_OK(nine->Copy(*empty, {0}, {0}, {0}));
+    EXPECT_TRUE(nine->Equal(*const_nine));
+  }
+
+  {
+    // Copy 0 element to destination with zero elements.
+    const auto empty = Literal::CreateFromShape(empty_r1_shape);
+    auto nine = Literal::CreateR1<float>({9});
+
+    TF_EXPECT_OK(empty->Copy(*nine, {0}, {0}, {0}));
+    EXPECT_TRUE(empty->Equal(*const_empty));
+  }
+}
+
 TEST_F(LiteralUtilTest, F16) {
   // Verify that the internal data views are consistent and that they
   // are in little endian format
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 839fe484888..8c7c2aa70ee 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -198,8 +198,8 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
     std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
     CollectProfileCandidates profile_candidates_for_computation(
         &hlo_to_profile_idx);
-    TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(
-        &profile_candidates_for_computation));
+    TF_RETURN_IF_ERROR(
+        computation->Accept(&profile_candidates_for_computation));
     return hlo_to_profile_idx;
   }
@@ -433,6 +433,10 @@ Status InitializeModuleHooks(
 StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec) {
+  const string timer_message =
+      "Compiling [" + module->name() + "] for CPU using JIT";
+  ScopedLoggingTimer compiling_timer(timer_message, 1);
+
   VLOG(1) << "Compiling: " << module->name();
   TF_RET_CHECK(stream_exec != nullptr);
   std::call_once(llvm_command_line_options_initialized,
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index c5275ede651..06c94e19de2 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -240,6 +240,13 @@ void IrEmitter::InitializeIrFunction(const string& function_name) {
     compute_function_->addFnAttr(llvm::Attribute::OptimizeForSize);
   }
 
+  if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
+    compute_function_->addFnAttr("unsafe-fp-math", "true");
+    compute_function_->addFnAttr("no-infs-fp-math", "true");
+    compute_function_->addFnAttr("no-nans-fp-math", "true");
+    compute_function_->addFnAttr("no-signed-zeros-fp-math", "true");
+  }
+
   ir_builder_.SetInsertPoint(llvm::BasicBlock::Create(
       /*Context=*/module_->getContext(),
       /*Name=*/"entry",
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
index 202a0171dbe..a40eb6afc2f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
@@ -87,6 +87,9 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
     }
   }
 
+  // TODO(b/65380986): Investigate if adding fast math flags for generated
+  // kernels makes sense.
+
   llvm::BasicBlock* entry_bb =
       llvm::BasicBlock::Create(function->getContext(), "entry", function);
   // Emit a "return void" at entry_bb's end, and sets the insert point before
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 749badf3f23..b84284046b0 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -201,6 +201,9 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
   }
   kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias);
 
+  // TODO(b/65380986): Investigate if adding fast math flags for generated
+  // kernels makes sense.
+
   // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX
   // treats it as a CUDA kernel.
   llvm::NamedMDNode* nvvm_annotations_node =
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index a8265483492..9205f5dc4e8 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -332,6 +332,53 @@ TEST_F(HloEvaluatorTest, DoesBroadcastScalar) {
   LiteralTestUtil::ExpectEqual(*result, *output_literal);
 }
 
+TEST_F(HloEvaluatorTest, DoesConcatenateSimple) {
+  HloComputation::Builder b(TestName());
+
+  HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR2<int64>({{-1, -2}, {100, 200}})));
+  HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR2<int64>({{-2, -3}, {-100, -200}})));
+
+  std::vector<HloInstruction*> operands = {operand1, operand2};
+
+  Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
+  b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0));
+
+  HloModule module(TestName());
+  auto computation = module.AddEntryComputation(b.Build());
+
+  std::unique_ptr<Literal> result =
+      evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie();
+
+  auto expected =
+      Literal::CreateR2<int64>({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
+  LiteralTestUtil::ExpectEqual(*expected, *result);
+}
+
+TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
+  HloComputation::Builder b(TestName());
+
+  HloInstruction* operand1 = b.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR1<int64>({100, 200})));
+  HloInstruction* operand2 = b.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR1<int64>({})));
+
+  std::vector<HloInstruction*> operands = {operand1, operand2};
+
+  Shape shape = ShapeUtil::MakeShape(S64, {2});
+  b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0));
+
+  HloModule module(TestName());
+  auto computation = module.AddEntryComputation(b.Build());
+
+  std::unique_ptr<Literal> result =
+      evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie();
+
+  auto expected = Literal::CreateR1<int64>({100, 200});
+  LiteralTestUtil::ExpectEqual(*expected, *result);
+}
+
 TEST_F(HloEvaluatorTest, ConvertWithSameLayout) {
   HloComputation::Builder b(TestName());
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 24ef4e09e7c..ce9e0db77e1 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -512,7 +512,6 @@ HloInstruction::CreateSelectAndScatter(
   instruction->set_parent(fused_root->parent());
   instruction->set_metadata(fused_root->metadata());
   instruction->CloneAndFuseInternal(fused_root);
-  instruction->CheckFusionInstruction();
   return instruction;
 }
 
@@ -636,7 +635,6 @@ HloInstruction* HloInstruction::FuseInstructionInternal(
   }
   HloInstruction* fused_instruction =
       CloneAndFuseInternal(instruction_to_fuse, add_output);
-  CheckFusionInstruction();
   return fused_instruction;
 }
 
@@ -822,74 +820,6 @@ bool HloInstruction::HasSideEffect() const {
   }
 }
 
-void HloInstruction::CheckFusionInstruction() const {
-  CHECK_EQ(opcode_, HloOpcode::kFusion);
-
-  // The parent fusion instruction of the fusion computation must be 'this'.
-  HloComputation* fused_computation = fused_instructions_computation();
-  CHECK_EQ(this, fused_computation->FusionInstruction());
-
-  // Fused root instruction and fused parameters must all be owned by the fusion
-  // computation.
-  bool root_owned = false;
-  const std::vector<HloInstruction*>& fused_parameters_ = fused_parameters();
-  const HloInstruction* fused_root_ = fused_expression_root();
-  std::vector<bool> parameter_owned(fused_parameters_.size(), false);
-  for (auto& instruction : fused_computation->instructions()) {
-    if (fused_root_ == instruction.get()) {
-      CHECK(!root_owned);
-      root_owned = true;
-    }
-    for (int i = 0; i < fused_parameters_.size(); ++i) {
-      if (fused_parameters_[i] == instruction.get()) {
-        CHECK(!parameter_owned[i]);
-        parameter_owned[i] = true;
-      }
-    }
-  }
-  CHECK(root_owned);
-  // Make sure all the parameter_owned entries are set
-  for (int i = 0; i < parameter_owned.size(); i++) {
-    CHECK(parameter_owned[i]);
-  }
-
-  // Fused root must have no users.
-  CHECK_EQ(0, fused_root_->user_count());
-
-  // All uses of fused instructions must be in the fusion computation, and every
-  // non-root instruction must have at least one use.
-  for (auto& instruction : fused_instructions_computation()->instructions()) {
-    if (instruction.get() != fused_root_) {
-      CHECK_GT(instruction->user_count(), 0);
-      for (auto& user : instruction->users()) {
-        CHECK_EQ(fused_computation, user->parent());
-      }
-    }
-  }
-
-  // Fused parameter instructions must be numbered contiguously and match up
-  // (shapes compatible) with their respective operand.
-  CHECK_EQ(operands_.size(), fused_parameters_.size());
-  std::vector<bool> parameter_numbers(fused_parameters_.size(), false);
-  for (auto fused_param : fused_parameters_) {
-    int64 param_no = fused_param->parameter_number();
-    CHECK_GE(param_no, 0);
-    CHECK_LT(param_no, fused_parameters_.size());
-    CHECK(!parameter_numbers[param_no]);
-    parameter_numbers[param_no] = true;
-    CHECK(ShapeUtil::Compatible(fused_param->shape(),
-                                operands_[param_no]->shape()));
-  }
-  // Make sure all the parameter_numbers entries were seen
-  for (int i = 0; i < parameter_numbers.size(); i++) {
-    CHECK(parameter_numbers[i]);
-  }
-
-  // Operands must be distinct.
-  std::set<HloInstruction*> operand_set(operands_.begin(), operands_.end());
-  CHECK_EQ(operand_set.size(), operands_.size());
-}
-
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
     const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
     HloComputation* computation) {
@@ -1194,7 +1124,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
           ->AddEmbeddedComputation(
               computation_builder.Build(FindOrDie(old_to_new, fused_root_))));
   new_instruction->set_parent(parent());
-  new_instruction->CheckFusionInstruction();
   return new_instruction;
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index ca6f27bd40e..bd8b8ac9bd8 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -900,9 +900,6 @@ class HloInstruction {
   // instruction to make it a bitcast.
   bool CouldBeBitcast() const;
 
-  // CHECKs various invariants of a fusion instruction.
-  void CheckFusionInstruction() const;
-
   // Get/Set the number of partitions per outer dimension (in order, starting
   // with outer-most dimension first). Currently used by the parallel cpu
   // backend to partition HLOs into parallel tasks.
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index c44be716cdf..d40fceb0765 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -130,6 +130,8 @@ class ShapeVerifier : public DfsHloVisitor {
   }
 
   Status HandleBroadcast(HloInstruction* broadcast) override {
+    TF_RET_CHECK(ShapeUtil::Rank(broadcast->operand(0)->shape()) ==
+                 broadcast->dimensions().size());
     return tensorflow::Status::OK();
   }
 
@@ -290,6 +292,123 @@ string ComputationsToString(
 
 }  // namespace
 
+Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
+  // The parent fusion instruction of the fusion computation must be 'fusion'.
+  HloComputation* fused_computation = fusion->fused_instructions_computation();
+  if (fusion != fused_computation->FusionInstruction()) {
+    return FailedPrecondition(
+        "Instruction of fused computation does not match expected instruction "
+        "%s.",
+        fusion->ToString().c_str());
+  }
+
+  // Fused root instruction and fused parameters must all be owned by the fusion
+  // computation.
+  bool root_owned = false;
+  const std::vector<HloInstruction*>& fused_parameters =
+      fusion->fused_parameters();
+  const HloInstruction* fused_root = fusion->fused_expression_root();
+  std::vector<bool> parameter_owned(fused_parameters.size(), false);
+  for (auto& instruction : fused_computation->instructions()) {
+    if (fused_root == instruction.get()) {
+      if (root_owned) {
+        return FailedPrecondition("Root appears more than once in %s.",
+                                  fusion->ToString().c_str());
+      }
+      root_owned = true;
+    }
+    for (int i = 0; i < fused_parameters.size(); ++i) {
+      if (fused_parameters[i] == instruction.get()) {
+        if (parameter_owned[i]) {
+          return FailedPrecondition("Parameter appears more than once in %s.",
+                                    fusion->ToString().c_str());
+        }
+        parameter_owned[i] = true;
+      }
+    }
+  }
+  if (!root_owned) {
+    return FailedPrecondition("Root not found in computation of %s.",
+                              fusion->ToString().c_str());
+  }
+  // Make sure all the parameter_owned entries are set
+  for (int i = 0; i < parameter_owned.size(); i++) {
+    if (!parameter_owned[i]) {
+      return FailedPrecondition("Parameter %d not found in computation of %s.",
+                                i, fusion->ToString().c_str());
+    }
+  }
+
+  // Fused root must have no users.
+  if (fused_root->user_count() != 0) {
+    return FailedPrecondition("Root of %s may not have users.",
+                              fusion->ToString().c_str());
+  }
+
+  // All uses of fused instructions must be in the fusion computation, and every
+  // non-root instruction must have at least one use.
+  for (auto& instruction :
+       fusion->fused_instructions_computation()->instructions()) {
+    if (instruction.get() != fused_root) {
+      if (instruction->user_count() == 0) {
+        return FailedPrecondition(
+            "Non-root instruction %s in %s must have users.",
+            instruction->ToString().c_str(), fusion->ToString().c_str());
+      }
+      for (auto& user : instruction->users()) {
+        if (fused_computation != user->parent()) {
+          return FailedPrecondition(
+              "Non-root instruction %s in %s may not have external users.",
+              instruction->ToString().c_str(), fusion->ToString().c_str());
+        }
+      }
+    }
+  }
+
+  // Fused parameter instructions must be numbered contiguously and match up
+  // (shapes compatible) with their respective operand.
+  CHECK_EQ(fusion->operands().size(), fused_parameters.size());
+  std::vector<bool> parameter_numbers(fused_parameters.size(), false);
+  for (auto fused_param : fused_parameters) {
+    int64 param_no = fused_param->parameter_number();
+    if (param_no < 0) {
+      return FailedPrecondition(
+          "Unexpected negative parameter number %lld in %s.", param_no,
+          fusion->ToString().c_str());
+    }
+    if (param_no >= fused_parameters.size()) {
+      return FailedPrecondition(
+          "Unexpected parameter number %lld in %s: higher than the number of "
+          "parameters %lu.",
+          param_no, fusion->ToString().c_str(), fused_parameters.size());
+    }
+    if (parameter_numbers[param_no]) {
+      return FailedPrecondition(
+          "Did not expect parameter number %lld more than once in %s.",
+          param_no, fusion->ToString().c_str());
+    }
+    parameter_numbers[param_no] = true;
+    if (!ShapeUtil::Compatible(fused_param->shape(),
+                               fusion->operand(param_no)->shape())) {
+      return FailedPrecondition(
+          "Shape mismatch between parameter number %lld and its operand in %s.",
+          param_no, fusion->ToString().c_str());
+    }
+  }
+  // Make sure all the parameter_numbers entries were seen
+  for (int i = 0; i < parameter_numbers.size(); i++) {
+    if (!parameter_numbers[i]) {
+      return FailedPrecondition("Did not see parameter number %d in %s.", i,
+                                fusion->ToString().c_str());
+    }
+  }
+
+  // TODO(b/65423525): We'd like to check that all operands are distinct.
+  // This is currently disabled due to the invariant being violated by
+  // multi-output fusion.
+  return tensorflow::Status::OK();
+}
+
 StatusOr<bool> HloVerifier::Run(HloModule* module) {
   tensorflow::gtl::FlatMap<string, const HloInstruction*> instructions;
   ShapeVerifier shape_verifier(shape_size_fn_);
@@ -298,6 +417,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
     for (const auto& instruction : computation->instructions()) {
       TF_RET_CHECK(instruction->parent() == computation.get());
       if (instruction->opcode() == HloOpcode::kFusion) {
+        TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction.get()));
         TF_RET_CHECK(
             ContainersEqual(instruction->called_computations(),
                             {instruction->fused_instructions_computation()}))
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index bc6800dae54..e35a7f3642c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -34,6 +34,9 @@ class HloVerifier : public HloPassInterface {
   StatusOr<bool> Run(HloModule* module) override;
 
  private:
+  // Checks various invariants of a fusion instruction.
+  Status CheckFusionInstruction(HloInstruction* fusion) const;
+
   // Returns the size of a Shape in bytes.
   const std::function<int64(const Shape&)> shape_size_fn_;
 };
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc
index 607abee33d9..064020896e7 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc
@@ -425,7 +425,6 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) {
   EXPECT_EQ(computation->root_instruction(), z);
   HloInstruction* y_fused = z->fused_expression_root();
   EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos);
-  z->CheckFusionInstruction();
 
   // This should see that the fusion computation contains a kCos operation,
   // and insert a new reduce-precision node at its input.
@@ -450,7 +449,6 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) {
   EXPECT_EQ(computation->root_instruction(), z);
   EXPECT_THAT(z->fused_expression_root(), y_fused);
   EXPECT_THAT(y_fused->operand(0), op::ReducePrecision(op::Parameter()));
-  z->CheckFusionInstruction();
 }
 
 TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) {
@@ -468,7 +466,6 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) {
           shape, HloInstruction::FusionKind::kLoop, y));
   EXPECT_IS_OK(computation->ReplaceUsesOfInstruction(y, z));
   EXPECT_IS_OK(computation->RemoveInstruction(y));
-  z->CheckFusionInstruction();
 
   // Confirm expected graph before adding reduce-precision ops.
   EXPECT_THAT(x->users(), UnorderedElementsAre(z));
@@ -498,7 +495,6 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) {
   EXPECT_THAT(x->users(), UnorderedElementsAre(z));
   EXPECT_EQ(computation->root_instruction(), z);
   EXPECT_THAT(z->fused_expression_root(), op::ReducePrecision(y_fused));
-  z->CheckFusionInstruction();
 }
 
 TEST_F(ReducePrecisionInsertionTest, MakeFilterFunctionNoSubstrings) {
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 297bfd93d12..858db8fa0e1 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -2471,8 +2471,8 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
       operand->shape().element_type(),
       AsInt64Slice(output_shape.dimensions()));
   // Do explicit broadcast for scalar.
   if (ShapeUtil::IsScalar(operand->shape())) {
-    return hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
-        broadcast_shape, operand, AsInt64Slice(broadcast_shape.dimensions())));
+    return hlo_builder_.AddInstruction(
+        HloInstruction::CreateBroadcast(broadcast_shape, operand, {}));
   }
   // Do explicit broadcast for degenerate broadcast.
   std::vector<int64> broadcast_dimensions;
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 606d801c84e..22d2b917a1d 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -67,7 +67,7 @@ class MultiOutputFusionTest : public HloTestBase {
         elem_shape0, HloOpcode::kAdd, param0, const0));
 
     HloInstruction* broadcast = builder.AddInstruction(
-        HloInstruction::CreateBroadcast(elem_shape2, add1, {0, 1}));
+        HloInstruction::CreateBroadcast(elem_shape2, add1, {}));
 
     auto param1 = builder.AddInstruction(
         HloInstruction::CreateParameter(1, elem_shape2, "1"));
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 84fcc0d0149..11e4ea888c7 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -24,6 +24,7 @@ py_library(
         "//tensorflow/contrib/deprecated:deprecated_py",
         "//tensorflow/contrib/distributions:distributions_py",
         "//tensorflow/contrib/eager/python:tfe",
+        "//tensorflow/contrib/estimator:estimator_py",
        "//tensorflow/contrib/factorization:factorization_py",
        "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
        "//tensorflow/contrib/framework:framework_py",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index d1d0e2823ad..5b3f0b3f6ee 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -29,6 +29,7 @@ from tensorflow.contrib import cudnn_rnn
 from tensorflow.contrib import data
 from tensorflow.contrib import deprecated
 from tensorflow.contrib import distributions
+from tensorflow.contrib import estimator
 from tensorflow.contrib import factorization
 from tensorflow.contrib import framework
 from tensorflow.contrib import gan
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 1b706159a3d..ce94f718a10 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -218,6 +218,48 @@ add_python_module("tensorflow/python/estimator/inputs/queues")
 add_python_module("tensorflow/python/feature_column")
 add_python_module("tensorflow/python/framework")
 add_python_module("tensorflow/python/grappler")
+add_python_module("tensorflow/python/keras")
+add_python_module("tensorflow/python/keras/activations")
+add_python_module("tensorflow/python/keras/applications")
+add_python_module("tensorflow/python/keras/applications/inception_v3")
+add_python_module("tensorflow/python/keras/applications/mobilenet")
+add_python_module("tensorflow/python/keras/applications/resnet50")
+add_python_module("tensorflow/python/keras/applications/vgg16")
+add_python_module("tensorflow/python/keras/applications/vgg19")
+add_python_module("tensorflow/python/keras/applications/xception")
+add_python_module("tensorflow/python/keras/backend")
+add_python_module("tensorflow/python/keras/callbacks")
+add_python_module("tensorflow/python/keras/constraints")
+add_python_module("tensorflow/python/keras/datasets")
+add_python_module("tensorflow/python/keras/datasets/boston_housing")
+add_python_module("tensorflow/python/keras/datasets/cifar10")
+add_python_module("tensorflow/python/keras/datasets/cifar100")
+add_python_module("tensorflow/python/keras/datasets/imdb")
+add_python_module("tensorflow/python/keras/datasets/mnist")
+add_python_module("tensorflow/python/keras/datasets/reuters")
+add_python_module("tensorflow/python/keras/initializers")
+add_python_module("tensorflow/python/keras/layers")
+add_python_module("tensorflow/python/keras/losses") +add_python_module("tensorflow/python/keras/metrics") +add_python_module("tensorflow/python/keras/models") +add_python_module("tensorflow/python/keras/optimizers") +add_python_module("tensorflow/python/keras/preprocessing") +add_python_module("tensorflow/python/keras/preprocessing/image") +add_python_module("tensorflow/python/keras/preprocessing/sequence") +add_python_module("tensorflow/python/keras/preprocessing/text") +add_python_module("tensorflow/python/keras/regularizers") +add_python_module("tensorflow/python/keras/utils") +add_python_module("tensorflow/python/keras/wrappers") +add_python_module("tensorflow/python/keras/wrappers/scikit_learn") +add_python_module("tensorflow/python/keras/_impl") +add_python_module("tensorflow/python/keras/_impl/keras") +add_python_module("tensorflow/python/keras/_impl/keras/applications") +add_python_module("tensorflow/python/keras/_impl/keras/datasets") +add_python_module("tensorflow/python/keras/_impl/keras/engine") +add_python_module("tensorflow/python/keras/_impl/keras/layers") +add_python_module("tensorflow/python/keras/_impl/keras/preprocessing") +add_python_module("tensorflow/python/keras/_impl/keras/utils") +add_python_module("tensorflow/python/keras/_impl/keras/wrappers") add_python_module("tensorflow/python/kernel_tests") add_python_module("tensorflow/python/kernel_tests/distributions") add_python_module("tensorflow/python/layers") @@ -299,6 +341,9 @@ add_python_module("tensorflow/contrib/distributions/python") add_python_module("tensorflow/contrib/distributions/python/kernel_tests") add_python_module("tensorflow/contrib/distributions/python/ops") add_python_module("tensorflow/contrib/distributions/python/ops/bijectors") +add_python_module("tensorflow/contrib/estimator") +add_python_module("tensorflow/contrib/estimator/python") +add_python_module("tensorflow/contrib/estimator/python/estimator") add_python_module("tensorflow/contrib/factorization") add_python_module("tensorflow/contrib/factorization/examples") add_python_module("tensorflow/contrib/factorization/kernels") diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index eb02f20457e..9dff8881559 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -142,6 +142,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/debug/cli/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/*_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py" "${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/*_test.py" @@ -246,6 +247,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Broken tensorboard test due to cmake issues. "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 # tensor_forest tests (also note that we exclude the hybrid tests for now) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. 
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order. diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index bc4fd10cac6..f6eeb016755 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -693,6 +693,7 @@ _cudnn_rnn_common_doc_string = """ canonical format. This is a typical use case: + * The user creates a CudnnRNN model. * The user query that parameter buffer size. * The user creates a variable of that size that serves as the parameter diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 7b916d82c1c..c417650a96f 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -10,6 +10,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:sloppy_ops", "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 1c0a5288f7e..c74e1369d5d 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -23,6 +23,8 @@ @@read_batch_features @@rejection_resample @@group_by_window +@@sloppy_interleave +@@sloppy_map """ from __future__ import absolute_import @@ -38,6 +40,7 @@ from tensorflow.contrib.data.python.ops.dataset_ops import read_batch_features from tensorflow.contrib.data.python.ops.dataset_ops import rejection_resample from tensorflow.contrib.data.python.ops.dataset_ops import TextLineDataset from tensorflow.contrib.data.python.ops.dataset_ops import TFRecordDataset +from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index fb2740ffef2..2f93c345027 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -146,6 +146,25 @@ py_test( ], ) +py_test( + name = "sloppy_transformation_dataset_op_test", + size = "small", + srcs = ["sloppy_transformation_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:sloppy_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + py_test( name = "list_files_dataset_op_test", size = "small", @@ -228,7 +247,7 @@ py_test( srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data", + "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index d05fbb7d285..4c1496ccf98 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import 
 from __future__ import print_function
 
+from collections import namedtuple
 import os
 import threading
 
@@ -489,8 +490,8 @@ class MapDatasetTest(test.TestCase):
     dataset_tuple = dataset_ops.Dataset.zip((labels, images))
 
     # convert dataset of tuples to dataset of namedtuples
-    Example = namedtuple("Example", ["label", "image"])
-    dataset_namedtuple = dataset_tuple.map(Example)
+    example = namedtuple("Example", ["label", "image"])
+    dataset_namedtuple = dataset_tuple.map(example)
 
     def preprocess_tuple(label, image):
       image = 2 * image
diff --git a/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py
new file mode 100644
index 00000000000..f9198bacfbd
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py
@@ -0,0 +1,475 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import math
+import threading
+import time
+
+from six.moves import zip_longest
+
+from tensorflow.contrib.data.python.ops import dataset_ops
+from tensorflow.contrib.data.python.ops import sloppy_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.platform import test
+
+
+class SloppyInterleaveDatasetTest(test.TestCase):
+
+  def setUp(self):
+    self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
+    self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
+    self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
+
+    self.repeat_count = 2
+
+    # Set up threading events used to sequence when items are produced that
+    # are subsequently interleaved. These events allow us to deterministically
+    # simulate slowdowns and force sloppiness.
+    self.read_coordination_events = {}
+    self.write_coordination_events = {}
+    # input values [4, 5, 6] are the common case for the tests; set defaults
+    for i in range(4, 7):
+      self.read_coordination_events[i] = threading.Semaphore(0)
+      self.write_coordination_events[i] = threading.Event()
+
+    def map_py_fn(x):
+      self.write_coordination_events[x].wait()
+      self.write_coordination_events[x].clear()
+      self.read_coordination_events[x].release()
+      return x * x
+
+    def map_fn(x):
+      return script_ops.py_func(map_py_fn, [x], x.dtype)
+
+    def interleave_fn(x):
+      dataset = dataset_ops.Dataset.from_tensors(x)
+      dataset = dataset.repeat(x)
+      return dataset.map(map_fn)
+
+    self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values)
+                    .repeat(self.repeat_count).apply(
+                        sloppy_ops.sloppy_interleave,
+                        args=(interleave_fn, self.cycle_length,
+                              self.block_length)))
+    self.iterator = self.dataset.make_initializable_iterator()
+    self.init_op = self.iterator.initializer
+    self.next_element = self.iterator.get_next()
+
+  def _interleave(self, lists, cycle_length, block_length):
+    """Python implementation of interleave used for testing."""
+    num_open = 0
+
+    # `all_iterators` acts as a queue of iterators over each element of
+    # `lists`.
+    all_iterators = [iter(l) for l in lists]
+
+    # `open_iterators` are the iterators whose elements are currently being
+    # interleaved.
+    open_iterators = []
+    for i in range(cycle_length):
+      if all_iterators:
+        open_iterators.append(all_iterators.pop(0))
+        num_open += 1
+      else:
+        open_iterators.append(None)
+
+    while num_open or all_iterators:
+      for i in range(cycle_length):
+        if open_iterators[i] is None:
+          if all_iterators:
+            open_iterators[i] = all_iterators.pop(0)
+            num_open += 1
+          else:
+            continue
+        for _ in range(block_length):
+          try:
+            yield next(open_iterators[i])
+          except StopIteration:
+            open_iterators[i] = None
+            num_open -= 1
+            break
+
+  def testPythonImplementation(self):
+    input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
+                   [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]
+
+    # Cycle length 1 acts like `Dataset.flat_map()`.
+    expected_elements = itertools.chain(*input_lists)
+    for expected, produced in zip(expected_elements,
+                                  self._interleave(input_lists, 1, 1)):
+      self.assertEqual(expected, produced)
+
+    # Cycle length > 1.
+    expected_elements = [
+        4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5,
+        6, 5, 6, 5, 6, 6
+    ]
+    for index, (expected, produced) in enumerate(
+        zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
+      self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
+                       (index, expected, produced))
+
+  def testPythonImplementationBlockLength(self):
+    input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2
+    expected_elements = [
+        4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5,
+        5, 6, 6, 5, 6, 6
+    ]
+    for index, (expected, produced) in enumerate(
+        zip_longest(expected_elements, self._interleave(input_lists, 2, 2))):
+      self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
+                       (index, expected, produced))
+
+  def testPythonImplementationEmptyLists(self):
+    input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [],
+                   [6, 6, 6, 6, 6, 6]]
+
+    expected_elements = [
+        4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6
+    ]
+    for index, (expected, produced) in enumerate(
+        zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
+      self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
+                       (index, expected, produced))
%s != %s" % + (index, expected, produced)) + + def _clear_coordination_events(self): + for i in range(4, 7): + self.read_coordination_events[i] = threading.Semaphore(0) + self.write_coordination_events[i].clear() + + def _allow_all_map_threads(self): + for i in range(4, 7): + self.write_coordination_events[i].set() + + def testSingleThreaded(self): + # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and + # `Dataset.flat_map()` and is single-threaded. No synchronization required. + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 1, + self.block_length: 1 + }) + + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1): + self.write_coordination_events[expected_element].set() + self.assertEqual(expected_element * expected_element, + sess.run(self.next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContention(self): + # num_threads > 1. + # Explicit coordination should result in `Dataset.interleave()` behavior + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1 + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 1)): + self.write_coordination_events[expected_element].set() + if done_first_event: # First event starts the worker threads. + self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event: + self.read_coordination_events[expected_element].acquire() + done_first_event = True + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContentionWithRaces(self): + """Tests where all the workers race in producing elements. + + Note: this is in contrast with the prevous test which carefully sequences + the execution of the map functions. + """ + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1 + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 1)): + if done_first_event: # First event starts the worker threads. + self._allow_all_map_threads() + self.read_coordination_events[expected_element].acquire() + else: + self.write_coordination_events[expected_element].set() + time.sleep(0.01) # Sleep to consistently "avoid" the race condition. + actual_element = sess.run(self.next_element) + if not done_first_event: + done_first_event = True + self.assertTrue( + self.read_coordination_events[expected_element].acquire(False)) + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContentionBlockLength(self): + # num_threads > 1. 
+    # Explicit coordination should result in `Dataset.interleave()` behavior
+    with self.test_session() as sess:
+      self._clear_coordination_events()
+      done_first_event = False
+      sess.run(
+          self.init_op,
+          feed_dict={
+              self.input_values: [4, 5, 6],
+              self.cycle_length: 2,
+              self.block_length: 2
+          })
+      for i, expected_element in enumerate(
+          self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+                           2)):
+        self.write_coordination_events[expected_element].set()
+        if done_first_event:  # First event starts the worker threads.
+          self.read_coordination_events[expected_element].acquire()
+        actual_element = sess.run(self.next_element)
+        if not done_first_event:
+          done_first_event = True
+          self.read_coordination_events[expected_element].acquire()
+        self.assertEqual(expected_element * expected_element, actual_element,
+                         "At index %s: %s expected, got: %s" %
+                         (i, expected_element, actual_element))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(self.next_element)
+
+  def testTwoThreadsNoContentionWithRacesAndBlocking(self):
+    """Tests where all the workers race in producing elements.
+
+    Note: this is in contrast with the previous test which carefully sequences
+    the execution of the map functions.
+    """
+    with self.test_session() as sess:
+      self._clear_coordination_events()
+      done_first_event = False
+      sess.run(
+          self.init_op,
+          feed_dict={
+              self.input_values: [4, 5, 6],
+              self.cycle_length: 2,
+              self.block_length: 2
+          })
+      for i, expected_element in enumerate(
+          self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+                           2)):
+        if done_first_event:  # First event starts the worker threads.
+          self._allow_all_map_threads()
+          self.read_coordination_events[expected_element].acquire()
+        else:
+          self.write_coordination_events[expected_element].set()
+        time.sleep(0.01)  # Sleep to consistently "avoid" the race condition.
+        actual_element = sess.run(self.next_element)
+        if not done_first_event:
+          done_first_event = True
+          self.assertTrue(
+              self.read_coordination_events[expected_element].acquire(False))
+        self.assertEqual(expected_element * expected_element, actual_element,
+                         "At index %s: %s expected, got: %s" %
+                         (i, expected_element, actual_element))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(self.next_element)
+
+  def testEmptyInput(self):
+    with self.test_session() as sess:
+      # Empty input.
+      self._clear_coordination_events()
+      sess.run(
+          self.init_op,
+          feed_dict={
+              self.input_values: [],
+              self.cycle_length: 2,
+              self.block_length: 3
+          })
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(self.next_element)
+
+  def testNonEmptyInputIntoEmptyOutputs(self):
+    # Non-empty input leading to empty output.
+    with self.test_session() as sess:
+      self._clear_coordination_events()
+      sess.run(
+          self.init_op,
+          feed_dict={
+              self.input_values: [0, 0, 0],
+              self.cycle_length: 2,
+              self.block_length: 3
+          })
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(self.next_element)
+
+  def testPartiallyEmptyOutputs(self):
+    # Mixture of non-empty and empty interleaved datasets.
+    with self.test_session() as sess:
+      self._clear_coordination_events()
+      done_first_event = False
+      sess.run(
+          self.init_op,
+          feed_dict={
+              self.input_values: [4, 0, 6],
+              self.cycle_length: 2,
+              self.block_length: 1
+          })
+      for i, expected_element in enumerate(
+          self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
+        self.write_coordination_events[expected_element].set()
+        if done_first_event:  # First event starts the worker threads
+          self.read_coordination_events[expected_element].acquire()
+        actual_element = sess.run(self.next_element)
+        if not done_first_event:
+          done_first_event = True
+          self.read_coordination_events[expected_element].acquire()
+        self.assertEqual(expected_element * expected_element, actual_element,
+                         "At index %s: %s expected, got: %s" %
+                         (i, expected_element, actual_element))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(self.next_element)
+
+  def testDelayedOutput(self):
+    # Explicitly control the sequence of events to ensure we correctly avoid
+    # head-of-line blocking.
+    with self.test_session() as sess:
+      self._clear_coordination_events()
+      sess.run(
+          self.init_op,
+          feed_dict={
+              self.input_values: [4, 5, 6],
+              self.cycle_length: 2,
+              self.block_length: 1
+          })
+
+      mis_ordering = [
+          4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6, 6,
+          6, 5, 5, 5, 5, 6, 6
+      ]
+      for element in mis_ordering:
+        self.write_coordination_events[element].set()
+        self.assertEqual(element * element, sess.run(self.next_element))
+        self.assertTrue(self.read_coordination_events[element].acquire(False))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(self.next_element)
+
+  def testBlockLengthWithContention(self):
+    with self.test_session() as sess:
+      self._clear_coordination_events()
+      done_first_event = False
+      sess.run(
+          self.init_op,
+          feed_dict={
+              self.input_values: [4, 5, 6],
+              self.cycle_length: 2,
+              self.block_length: 3
+          })
+      # Test against a generating sequence that differs from the uncontended
+      # case, in order to prove sloppy correctness.
+      for i, expected_element in enumerate(
+          self._interleave(
+              [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count,
+              cycle_length=2,
+              block_length=2)):
+        self.write_coordination_events[expected_element].set()
+        if done_first_event:  # First event starts the worker threads.
+          self.read_coordination_events[expected_element].acquire()
+        actual_element = sess.run(self.next_element)
+        if not done_first_event:
+          self.read_coordination_events[expected_element].acquire()
+          done_first_event = True
+        self.assertEqual(expected_element * expected_element, actual_element,
+                         "At index %s: %s expected, got: %s" %
+                         (i, expected_element, actual_element))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(self.next_element)
+
+  def testEarlyExit(self):
+    # Exiting without consuming all input should not block
+    with self.test_session() as sess:
+      self._clear_coordination_events()
+      sess.run(
+          self.init_op,
+          feed_dict={
+              self.input_values: [4, 5, 6],
+              self.cycle_length: 3,
+              self.block_length: 2
+          })
+      for i in range(4, 7):
+        self.write_coordination_events[i].set()
+      elem = sess.run(self.next_element)  # Start all workers
+      # Allow the one successful worker to progress beyond the py_func again.
+      elem = int(math.sqrt(elem))
+      self.write_coordination_events[elem].set()
+      self.read_coordination_events[elem].acquire()
+      # Allow the prefetch to succeed.
+      for i in range(4, 7):
+        self.read_coordination_events[i].acquire()
+        self.write_coordination_events[i].set()
+
+  def testTooManyReaders(self):
+
+    def interleave_fn(x):
+      dataset = dataset_ops.Dataset.from_tensors(x)
+      dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64))
+      return dataset
+
+    dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
+    dataset = dataset.repeat(self.repeat_count)
+    dataset = dataset.apply(
+        sloppy_ops.sloppy_interleave,
+        args=(interleave_fn,),
+        kwargs={"cycle_length": 16,
+                "block_length": 2})
+    iterator = dataset.make_one_shot_iterator()
+
+    with self.test_session() as sess:
+      output_values = []
+      for _ in range(30):
+        output_values.append(sess.run(iterator.get_next()))
+
+    expected_values = self._interleave(
+        [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
+    self.assertItemsEqual(output_values, expected_values)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 8afd122d82d..94969c1c704 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -32,6 +32,21 @@ py_library(
     ],
 )
 
+py_library(
+    name = "sloppy_ops",
+    srcs = ["sloppy_ops.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":dataset_ops",
+        "//tensorflow/contrib/data/python/framework:function",
+        "//tensorflow/contrib/data/python/util:nest",
+        "//tensorflow/python:dataset_ops_gen",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:platform",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
diff --git a/tensorflow/contrib/data/python/ops/sloppy_ops.py b/tensorflow/contrib/data/python/ops/sloppy_ops.py
new file mode 100644
index 00000000000..010bd31161f
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/sloppy_ops.py
@@ -0,0 +1,120 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Non-deterministic dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.framework import function
+from tensorflow.contrib.data.python.ops import dataset_ops
+from tensorflow.contrib.data.python.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+
+
+class SloppyInterleaveDataset(dataset_ops.Dataset):
+  """A `Dataset` that maps a function over its input and flattens the result."""
+
+  def __init__(self, input_dataset, map_func, cycle_length, block_length):
+    """See `tf.contrib.data.sloppy_interleave()` for details."""
+    super(SloppyInterleaveDataset, self).__init__()
+    self._input_dataset = input_dataset
+
+    @function.Defun(*nest.flatten(input_dataset.output_types))
+    def tf_map_func(*args):
+      """A wrapper for Defun that facilitates shape inference."""
+      # Pass in shape information from the input_dataset.
+      for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
+        arg.set_shape(shape)
+
+      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
+
+      if nest.is_sequence(nested_args):
+        dataset = map_func(*nested_args)
+      else:
+        dataset = map_func(nested_args)
+
+      if not isinstance(dataset, dataset_ops.Dataset):
+        raise TypeError("`map_func` must return a `Dataset` object.")
+
+      self._output_types = dataset.output_types
+      self._output_shapes = dataset.output_shapes
+
+      return dataset.make_dataset_resource()
+
+    self._map_func = tf_map_func
+    self._map_func.add_to_graph(ops.get_default_graph())
+
+    self._cycle_length = ops.convert_to_tensor(
+        cycle_length, dtype=dtypes.int64, name="cycle_length")
+    self._block_length = ops.convert_to_tensor(
+        block_length, dtype=dtypes.int64, name="block_length")
+
+  def make_dataset_resource(self):
+    return gen_dataset_ops.sloppy_interleave_dataset(
+        self._input_dataset.make_dataset_resource(),
+        self._map_func.captured_inputs,
+        self._cycle_length,
+        self._block_length,
+        f=self._map_func,
+        output_types=nest.flatten(self.output_types),
+        output_shapes=nest.flatten(self.output_shapes))
+
+  @property
+  def output_shapes(self):
+    return self._output_shapes
+
+  @property
+  def output_types(self):
+    return self._output_types
+
+
+def sloppy_interleave(dataset, map_func, cycle_length, block_length):
+  """Maps `map_func` across `dataset`, and interleaves the results.
+
+  The resulting dataset is almost identical to `interleave`. The key
+  difference is that if retrieving a value from a given output iterator would
+  cause `get_next` to block, that iterator will be skipped, and consumed
+  when next available. If consuming from all iterators would cause the
+  `get_next` call to block, the `get_next` call blocks until the first value is
+  available.
+
+  If the underlying datasets produce elements as fast as they are consumed, the
+  `sloppy_interleave` dataset behaves identically to the `interleave` dataset.
+  However, if an underlying dataset would block the consumer, the
+  `sloppy_interleave` dataset can violate the round-robin order (respected by
+  the `interleave` dataset), producing an element from a different underlying
+  dataset instead.
+
+  WARNING: The order of elements in the resulting dataset is not
+  deterministic. Use `Dataset.interleave()` if you want the elements to have a
+  deterministic order.
+
+  Args:
+    dataset: A `Dataset` that produces elements to feed to `map_func`.
+    map_func: A function mapping a nested structure of tensors (having shapes
+      and types defined by `self.output_shapes` and `self.output_types`) to a
+      `Dataset`.
+    cycle_length: The number of threads to interleave from in parallel.
+    block_length: The number of consecutive elements to pull from a thread
+      before advancing to the next thread. Note: `sloppy_interleave` may skip
+      the remainder of a block of `block_length` elements in order to avoid
+      blocking.
+
+  Returns:
+    A `Dataset`.
+  """
+  return SloppyInterleaveDataset(dataset, map_func, cycle_length, block_length)
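As a concrete illustration of the semantics above, here is a minimal usage sketch modeled on the unit tests earlier in this diff (the element values and `interleave_fn` are illustrative):

```python
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import sloppy_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops


def interleave_fn(x):
  # Each input element expands into a small dataset; with sloppy_interleave,
  # whichever worker has data ready supplies the next output element.
  return dataset_ops.Dataset.from_tensors(x).repeat(
      math_ops.cast(x, dtype=dtypes.int64))

dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
dataset = sloppy_ops.sloppy_interleave(
    dataset, interleave_fn, cycle_length=2, block_length=2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()  # Elements arrive in a "sloppy" order.
```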
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index e29314099d7..1b831f8afba 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -2,11 +2,14 @@ licenses(["notice"])  # Apache 2.0
 
 package(default_visibility = ["//tensorflow:internal"])
 
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
 py_library(
     name = "tfe",
     srcs = ["tfe.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":saver",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:util",
         "//tensorflow/python/eager:backprop",
@@ -18,6 +21,28 @@ py_library(
     ],
 )
 
+py_library(
+    name = "saver",
+    srcs = ["saver.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:training",
+    ],
+)
+
+cuda_py_test(
+    name = "saver_test",
+    srcs = ["saver_test.py"],
+    additional_deps = [
+        ":saver",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:variables",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py
new file mode 100644
index 00000000000..12c902a4b66
--- /dev/null
+++ b/tensorflow/contrib/eager/python/saver.py
@@ -0,0 +1,122 @@
+"""Saver for eager mode TensorFlow."""
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import saver as _saver
+
+
+def _init_from_checkpoint(self, *args, **kwargs):
+  """Overrides default init by loading value from checkpoint."""
+  self.old_init(*args, **kwargs)
+  # pylint: disable=protected-access
+  if self._shared_name not in self.ckpt_var_cache:
+    raise errors.NotFoundError(None, None,
+                               "%s not found in checkpoint" % self._shared_name)
+
+  val = self.ckpt_var_cache[self._shared_name]
+  if val is not None:
+    self.assign(self.ckpt_var_cache[self._shared_name])
+    # Avoid assigning for the second time.
+    self.ckpt_var_cache[self._shared_name] = None
+  # pylint: enable=protected-access
+
+
+class Saver(object):
+  """A simple tf.train.Saver adapter for eager mode.
+
+  The save and restore APIs are similar to those of tf.train.Saver, except
+  that a session is not needed.
+
+  restore_on_create is eager mode's way of reloading checkpoint values during
+  execution (unlike graph mode's reload before run).
+
+  Args:
+    var_list: See tf.train.Saver. Works the same for save/restore. Ignored
+      by restore_on_create.
+  """
+
+  def __init__(self, var_list=None):
+    self._saver = _saver.Saver(var_list=var_list)
+
+  def save(self, save_path, global_step=None):
+    """Saves variables.
+
+    Args:
+      save_path: See save method in tf.train.Saver.
+      global_step: See save method in tf.train.Saver.
+
+    Returns:
+      See save method in tf.train.Saver.
+    """
+    return self._saver.save(None, save_path, global_step=global_step)
+
+  def restore(self, save_path):
+    """Restores previously saved variables.
+
+    Args:
+      save_path: See restore method in tf.train.Saver.
+    """
+    self._saver.restore(None, save_path)
+
+  @contextlib.contextmanager
+  def maybe_restore_on_create(self, save_path):
+    """ContextManager that restores variables on creation.
+
+    When save_path is None (e.g. no checkpoint), does nothing.
+    Otherwise, it preloads all values from the checkpoint. When the
+    corresponding variable is first created, it assigns the checkpoint
+    value to the variable.
+
+    Args:
+      save_path: Same as the save_path of restore. If None, do not restore.
+
+    Yields:
+      Nothing.
+
+    Raises:
+      NotFoundError: If a variable is not found in the checkpoint.
+    """
+    if save_path:
+      ckpt_var_cache = dict()
+      reader = checkpoint_utils.load_checkpoint(save_path)
+      for k, _ in checkpoint_utils.list_variables(save_path):
+        ckpt_var_cache[k] = reader.get_tensor(k)
+
+      old_init = getattr(
+          resource_variable_ops.ResourceVariable, "_init_from_args", None)
+      assert old_init, "ResourceVariable misses _init_from_args method."
+      setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
+              _init_from_checkpoint)
+      setattr(resource_variable_ops.ResourceVariable, "old_init", old_init)
+      setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache",
+              ckpt_var_cache)
+    try:
+      yield
+    except Exception as e:
+      raise e
+    finally:
+      if save_path:
+        setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
+                old_init)
+        setattr(resource_variable_ops.ResourceVariable, "old_init", None)
+        setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", None)
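A minimal usage sketch for this Saver, distilled from saver_test.py below (the checkpoint path is illustrative, and eager execution must be enabled):

```python
import os

from tensorflow.contrib.eager.python import saver as eager_saver
from tensorflow.python.eager import context
from tensorflow.python.ops import resource_variable_ops

with context.eager_mode():
  v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')

  ckpt_prefix = os.path.join('/tmp', 'ckpt')  # Illustrative path.
  saver = eager_saver.Saver()
  saver.save(ckpt_prefix)

  v1.assign(2.0)              # Mutate the variable...
  saver.restore(ckpt_prefix)  # ...then roll it back; v1 reads 1.0 again.

  # Alternatively, assign checkpoint values lazily as variables are created:
  # with saver.maybe_restore_on_create(ckpt_prefix):
  #   ...build the model...
```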
diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py
new file mode 100644
index 00000000000..b8ff566ec2e
--- /dev/null
+++ b/tensorflow/contrib/eager/python/saver_test.py
@@ -0,0 +1,88 @@
+"""Tests for eager mode Saver."""
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.eager.python import saver as _saver
+from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+
+
+class SaverTest(test.TestCase):
+
+  def testBasics(self):
+    with context.eager_mode():
+      v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
+      def model():
+        return array_ops.constant(2.0) * v1
+
+      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
+
+      _ = model()
+      saver = _saver.Saver()
+      saver.save(ckpt_prefix)
+      v1.assign(2.0)
+      self.assertEqual(v1.read_value().numpy(), 2.0)
+
+      saver.restore(ckpt_prefix)
+      self.assertEqual(v1.read_value().numpy(), 1.0)
+
+  def testRestoreOnCreate(self):
+    with context.eager_mode():
+      def model(init_val):
+        v1 = resource_variable_ops.ResourceVariable(init_val, name='v1')
+        return array_ops.constant(1.0) * v1
+
+      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
+      _ = model(1.0)
+      saver = _saver.Saver()
+      saver.save(ckpt_prefix)
+
+      with ops.Graph().as_default():
+        saver = _saver.Saver()
+        with saver.maybe_restore_on_create(ckpt_prefix):
+          # Value is from checkpoint, but not from argument.
+          ret = model(2.0)
+          self.assertEqual(ret.numpy(), 1.0)
+          # Creating it a second time won't re-assign the checkpoint value.
+          v1_2 = resource_variable_ops.ResourceVariable(3.0, name='v1')
+          self.assertEqual(v1_2.read_value().numpy(), 3.0)
+
+  def testRestoreNotFound(self):
+    with context.eager_mode():
+      def model(v):
+        return array_ops.constant(1.0) * v
+
+      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
+      _ = model(resource_variable_ops.ResourceVariable(1.0, name='v1'))
+      saver = _saver.Saver()
+      saver.save(ckpt_prefix)
+
+      with self.assertRaisesRegexp(errors.NotFoundError,
+                                   'v2 not found in checkpoint'):
+        with saver.maybe_restore_on_create(ckpt_prefix):
+          _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index aa0276dfd91..2c7494a0a86 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -42,6 +42,8 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
 @@inf_nan_callback
 @@nan_callback
 @@seterr
+
+@@Saver
 """
 
 from __future__ import absolute_import
@@ -51,6 +53,7 @@ from __future__ import print_function
 
 # pylint:disable=g-bad-import-order,g-import-not-at-top,unused-import
 #
+from tensorflow.contrib.eager.python.saver import Saver
 from tensorflow.python.util.all_util import remove_undocumented
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager.custom_gradient import custom_gradient
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
new file mode 100644
index 00000000000..46cdf086ddc
--- /dev/null
+++ b/tensorflow/contrib/estimator/BUILD
@@ -0,0 +1,61 @@
+package(
+    default_visibility = [
+        "//tensorflow:internal",
+    ],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+    name = "estimator_py",
+    srcs = ["__init__.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":extenders",
+    ],
+)
+
+py_library(
+    name = "extenders",
+    srcs = [
+        "python/estimator/extenders.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:util",
+        "//tensorflow/python/estimator",
+        "//tensorflow/python/estimator:model_fn",
+        "//tensorflow/python/estimator:util",
+    ],
+)
+
+py_test(
+    name = "extenders_test",
+    size = "small",
+    srcs = ["python/estimator/extenders_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":extenders",
+        "//tensorflow/contrib/data/python/ops:dataset_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:metrics",
+        "//tensorflow/python/estimator",
+        "//tensorflow/python/estimator:linear",
+        "//tensorflow/python/feature_column",
+        "//third_party/py/numpy",
+    ],
+)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
new file mode 100644
index 00000000000..9180a3acc36
--- /dev/null
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental utilities re: tf.estimator.*."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.estimator.python.estimator.extenders import *
+
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long,wildcard-import
+
+_allowed_symbols = ['add_metrics']
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py
new file mode 100644
index 00000000000..45dd9ef70dd
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/extenders.py
@@ -0,0 +1,124 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extenders of tf.estimator.Estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator import estimator as estimator_lib
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator import util as estimator_util
+from tensorflow.python.util import tf_inspect
+
+_VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config'])
+
+
+def add_metrics(estimator, metric_fn):
+  """Creates a new ${tf.estimator.Estimator} which has the given metrics.
+
+  Example:
+
+  ```python
+  def my_auc(labels, predictions):
+    return {'auc': tf.metrics.auc(labels, predictions['logistic'])}
+
+  estimator = tf.estimator.DNNClassifier(...)
+  estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
+  estimator.train(...)
+  estimator.evaluate(...)
+  ```
+  Example usage of a custom metric which uses features:
+
+  ```python
+  def my_auc(features, labels, predictions):
+    return {'auc': tf.metrics.auc(
+        labels, predictions['logistic'], weights=features['weight'])}
+
+  estimator = tf.estimator.DNNClassifier(...)
+  estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
+  estimator.train(...)
+  estimator.evaluate(...)
+  ```
+
+  Args:
+    estimator: A ${tf.estimator.Estimator} object.
+    metric_fn: A function which should obey the following signature:
+      - Args: can only have the following four arguments in any order:
+        * predictions: Predictions `Tensor` or dict of `Tensor` created by
+          given `estimator`.
+        * features: Input `dict` of `Tensor` objects created by `input_fn`
+          which is given to `estimator.evaluate` as an argument.
+        * labels: Labels `Tensor` or dict of `Tensor` created by `input_fn`
+          which is given to `estimator.evaluate` as an argument.
+        * config: config attribute of the `estimator`.
+      - Returns:
+        Dict of metric results keyed by name. Final metrics are a union of
+        this and the `estimator`'s existing metrics. If there is a name
+        conflict between this and the `estimator`'s existing metrics, this
+        will override the existing one. The values of the dict are the results
+        of calling a metric function, namely a `(metric_tensor, update_op)`
+        tuple.
+
+  Returns:
+    A new ${tf.estimator.Estimator} which has a union of original metrics with
+    given ones.
+  """
+  _verify_metric_fn_args(metric_fn)
+
+  def new_model_fn(features, labels, mode):
+    spec = _get_model_fn(estimator)(features, labels, mode)
+    if mode != model_fn_lib.ModeKeys.EVAL:
+      return spec
+    new_metrics = _call_metric_fn(metric_fn, features, labels, spec.predictions,
+                                  estimator.config)
+    all_metrics = spec.eval_metric_ops or {}
+    all_metrics.update(new_metrics)
+    return spec._replace(eval_metric_ops=all_metrics)
+
+  return estimator_lib.Estimator(
+      model_fn=new_model_fn,
+      model_dir=estimator.model_dir,
+      config=estimator.config)
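The name-conflict override described in the docstring can be seen in condensed form, adapted from the `test_overrides_existing_metrics` case later in this diff (`estimator` and `input_fn` are assumed to exist, as in that test):

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import metrics as metrics_lib


def metric_fn():
  # Reports a metric under the name 'auc', shadowing the canned
  # LinearClassifier's own AUC metric at evaluation time.
  return {'auc': metrics_lib.mean(constant_op.constant([2.]))}

estimator = extenders.add_metrics(estimator, metric_fn)
metrics = estimator.evaluate(input_fn=input_fn)
assert metrics['auc'] == 2.0
```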
+
+
+# TODO(ispir): Move this to tf.estimator.Estimator.
+def _get_model_fn(estimator):
+  return estimator._call_model_fn  # pylint: disable=protected-access
+
+
+def _verify_metric_fn_args(metric_fn):
+  args = set(estimator_util.fn_args(metric_fn))
+  if tf_inspect.ismethod(metric_fn):
+    if 'self' in args:
+      args.remove('self')
+  invalid_args = list(args - _VALID_METRIC_FN_ARGS)
+  if invalid_args:
+    raise ValueError('metric_fn (%s) has the following unexpected args: %s' %
+                     (metric_fn, invalid_args))
+
+
+def _call_metric_fn(metric_fn, features, labels, predictions, config):
+  """Calls metric fn with proper arguments."""
+  metric_fn_args = estimator_util.fn_args(metric_fn)
+  kwargs = {}
+  if 'features' in metric_fn_args:
+    kwargs['features'] = features
+  if 'labels' in metric_fn_args:
+    kwargs['labels'] = labels
+  if 'predictions' in metric_fn_args:
+    kwargs['predictions'] = predictions
+  if 'config' in metric_fn_args:
+    kwargs['config'] = config
+  return metric_fn(**kwargs)
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py
new file mode 100644
index 00000000000..422c16d24e9
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py
@@ -0,0 +1,135 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""extenders tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import dataset_ops
+from tensorflow.contrib.estimator.python.estimator import extenders
+from tensorflow.python.estimator import run_config
+from tensorflow.python.estimator.canned import linear
+from tensorflow.python.feature_column import feature_column as fc
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.platform import test
+
+
+def get_input_fn(x, y):
+
+  def input_fn():
+    dataset = dataset_ops.Dataset.from_tensor_slices({'x': x, 'y': y})
+    iterator = dataset.make_one_shot_iterator()
+    features = iterator.get_next()
+    labels = features.pop('y')
+    return features, labels
+
+  return input_fn
+
+
+class AddMetricsTest(test.TestCase):
+
+  def test_should_add_metrics(self):
+    input_fn = get_input_fn(
+        x=np.arange(4)[:, None, None], y=np.ones(4)[:, None])
+    estimator = linear.LinearClassifier([fc.numeric_column('x')])
+
+    def metric_fn(features):
+      return {'mean_x': metrics_lib.mean(features['x'])}
+
+    estimator = extenders.add_metrics(estimator, metric_fn)
+
+    estimator.train(input_fn=input_fn)
+    metrics = estimator.evaluate(input_fn=input_fn)
+    self.assertIn('mean_x', metrics)
+    self.assertEqual(1.5, metrics['mean_x'])
+    # Assert that it keeps the original estimator's metrics.
+    self.assertIn('auc', metrics)
+
+  def test_should_error_out_for_not_recognized_args(self):
+    estimator = linear.LinearClassifier([fc.numeric_column('x')])
+
+    def metric_fn(features, not_recognized):
+      _, _ = features, not_recognized
+      return {}
+
+    with self.assertRaisesRegexp(ValueError, 'not_recognized'):
+      estimator = extenders.add_metrics(estimator, metric_fn)
+
+  def test_all_supported_args(self):
+    input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])
+    estimator = linear.LinearClassifier([fc.numeric_column('x')])
+
+    def metric_fn(features, predictions, labels, config):
+      self.assertIn('x', features)
+      self.assertIsNotNone(labels)
+      self.assertIn('logistic', predictions)
+      self.assertTrue(isinstance(config, run_config.RunConfig))
+      return {}
+
+    estimator = extenders.add_metrics(estimator, metric_fn)
+
+    estimator.train(input_fn=input_fn)
+    estimator.evaluate(input_fn=input_fn)
+
+  def test_all_supported_args_in_different_order(self):
+    input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])
+    estimator = linear.LinearClassifier([fc.numeric_column('x')])
+
+    def metric_fn(labels, config, features, predictions):
+      self.assertIn('x', features)
+      self.assertIsNotNone(labels)
+      self.assertIn('logistic', predictions)
+      self.assertTrue(isinstance(config, run_config.RunConfig))
+      return {}
+
+    estimator = extenders.add_metrics(estimator, metric_fn)
+
+    estimator.train(input_fn=input_fn)
+    estimator.evaluate(input_fn=input_fn)
+
+  def test_all_args_are_optional(self):
+    input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])
+    estimator = linear.LinearClassifier([fc.numeric_column('x')])
+
+    def metric_fn():
+      return {'two': metrics_lib.mean(constant_op.constant([2.]))}
+
+    estimator = extenders.add_metrics(estimator, metric_fn)
+
+    estimator.train(input_fn=input_fn)
+    metrics = estimator.evaluate(input_fn=input_fn)
+    self.assertEqual(2., metrics['two'])
+
+  def test_overrides_existing_metrics(self):
+    input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])
+    estimator = linear.LinearClassifier([fc.numeric_column('x')])
+    estimator.train(input_fn=input_fn)
+    metrics = estimator.evaluate(input_fn=input_fn)
+    self.assertNotEqual(2., metrics['auc'])
+
+    def metric_fn():
+      return {'auc': metrics_lib.mean(constant_op.constant([2.]))}
+
+    estimator = extenders.add_metrics(estimator, metric_fn)
+    metrics = estimator.evaluate(input_fn=input_fn)
+    self.assertEqual(2., metrics['auc'])
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD
index f5d21278db8..9b34cf1bdb0 100644
--- a/tensorflow/contrib/fused_conv/BUILD
+++ b/tensorflow/contrib/fused_conv/BUILD
@@ -60,12 +60,14 @@ tf_kernel_library(
     srcs = [
         "kernels/fused_conv2d_bias_activation_op.cc",
        "kernels/fused_conv2d_bias_activation_op.h",
+        "kernels/fused_conv_ops_gpu.h",
     ],
     prefix = "fused_conv2d_bias_activation_op",
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_proto_parsing",
+        "//tensorflow/core:stream_executor",
         "//tensorflow/core/kernels:bounds_check_lib",
         "//tensorflow/core/kernels:conv_2d_hdrs",
         "//tensorflow/core/kernels:conv_ops_gpu_hdrs",
@@ -81,6 +83,7 @@ tf_custom_op_library(
     srcs = [
         "kernels/fused_conv2d_bias_activation_op.cc",
         "kernels/fused_conv2d_bias_activation_op.h",
+        "kernels/fused_conv_ops_gpu.h",
         "ops/fused_conv2d_bias_activation_op.cc",
     ],
     deps = [
@@ -94,12 +97,8 @@ tf_custom_op_library(
 )
 
 tf_gen_op_libs(
-    op_lib_names = [
-        "fused_conv2d_bias_activation_op",
-    ],
-    deps = [
-        "//tensorflow/core:lib_proto_parsing",
-    ],
+    op_lib_names = ["fused_conv2d_bias_activation_op"],
+    deps = ["//tensorflow/core:lib_proto_parsing"],
 )
 
 tf_gen_op_wrapper_py(
@@ -109,7 +108,7 @@ tf_gen_op_wrapper_py(
 
 cuda_py_test(
     name = "fused_conv2d_bias_activation_op_test",
-    size = "small",
+    size = "large",
     srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"],
     additional_deps = [
         ":fused_conv_py",
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index dc0701b234f..675ff2be388 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#define EIGEN_USE_THREADS
-
 #if GOOGLE_CUDA
 #define EIGEN_USE_GPU
 #endif  // GOOGLE_CUDA
@@ -31,8 +29,8 @@ limitations under the License.
 #include "tensorflow/core/kernels/conv_2d.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/util/padding.h"
-#include "tensorflow/core/util/tensor_format.h"
 #include "tensorflow/core/util/use_cudnn.h"
@@ -40,38 +38,84 @@ limitations under the License.
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/util/activation_mode.h"
 #endif  // GOOGLE_CUDA
+
 namespace tensorflow {
 
-typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-template <typename Device, typename T>
-struct LaunchConvOp;
+template <typename T>
+struct RawType {
+  using type = T;
+};
 
-template <typename T>
+template <>
+struct RawType<qint8> {
+  using type = int8;
+};
+
+// Template struct to convert int8x4 to int32.
+// (for NCHW_VECT_C with element type int8, we can consider it to be
+// an NCHW layout with element type int32 for operations like padding).
+template <typename T>
+struct Int8x4ToInt32 {
+  // By default, do not change T.
+  using type = T;
+};
+
+template <>
+struct Int8x4ToInt32<int8> {
+  using type = int32;
+};
+
+// T is the element type of the conv_input, filter and side_input tensors.
+// BiasType is the element type of the bias tensor, which can be different.
+// ScaleType is the type used for conv_input_scale, side_input_scale.
+template <typename Device, typename T, typename BiasType, typename ScaleType>
 class FusedConv2DBiasActivationOp : public OpKernel {
  public:
   explicit FusedConv2DBiasActivationOp(OpKernelConstruction* context)
       : OpKernel(context) {
-    string data_format;
-    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
-    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+    string data_format_str, filter_format_str;
+    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
+    OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
                 errors::InvalidArgument("Invalid data format"));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("filter_format", &filter_format_str));
     OP_REQUIRES(context,
-                (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW),
-                errors::InvalidArgument("Current implementation only supports "
-                                        "NHWC and NCHW data formats."));
-    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
-    OP_REQUIRES(context, strides_.size() == 4,
+                FilterFormatFromString(filter_format_str, &filter_format_),
+                errors::InvalidArgument("Invalid filter format"));
+
+    std::vector<int32> strides;
+    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides));
+    OP_REQUIRES(context, strides.size() == 4,
                 errors::InvalidArgument("Sliding window strides field must "
                                         "specify 4 dimensions"));
+
+    stride_rows_ = GetTensorDim(strides, data_format_, 'H');
+    stride_cols_ = GetTensorDim(strides, data_format_, 'W');
     OP_REQUIRES(
         context,
-        (GetTensorDim(strides_, data_format_, 'N') == 1 &&
-         GetTensorDim(strides_, data_format_, 'C') == 1),
-        errors::InvalidArgument("Current implementation does not yet support "
-                                "strides in the batch and depth dimensions."));
-    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+        (GetTensorDim(strides, data_format_, 'N') == 1 &&
+         GetTensorDim(strides, data_format_, 'C') == 1),
+        errors::InvalidArgument("Convolutional strides are not supported in "
+                                "the batch or depth dimensions."));
+
+    // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
+    constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
+
+    // Note: Only NCHW_VECT_C format is supported for int8.
+    // This is because it is expected to be the fastest, and our previous tests
+    // found cudnn 6 does not fully support the other formats for int8 mode.
+    OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
+                errors::InvalidArgument(
+                    "qint8 should be used with data_format NCHW_VECT_C."));
+
+    OP_REQUIRES(context, (is_int8x4 == (filter_format_ == FORMAT_OIHW_VECT_I)),
+                errors::InvalidArgument(
+                    "qint8 should be used with filter_format OIHW_VECT_I."));
+
+    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_type_));
+    eigen_padding_type_ = BrainPadding2EigenPadding(padding_type_);
     string activation_mode_str;
     OP_REQUIRES_OK(context, context->GetAttr("activation_mode",
                                              &activation_mode_str));
@@ -79,130 +123,111 @@ class FusedConv2DBiasActivationOp : public OpKernel {
                                          &activation_mode_));
     OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU,
                 errors::InvalidArgument("Current implementation only supports "
-                                        "relu as the activation mode."));
+                                        "RELU as the activation function."));
     cudnn_use_autotune_ = CudnnUseAutotune();
+    float conv_input_scale_flt, side_input_scale_flt;
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("conv_input_scale", &conv_input_scale_flt));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("side_input_scale", &side_input_scale_flt));
+    conv_input_scale_ = conv_input_scale_flt;
+    side_input_scale_ = side_input_scale_flt;
+  }
+
+  Status CheckShape(const Tensor& tensor, const string& tensor_name) {
+    const int num_dims = tensor.dims();
+    for (int i = 0; i < num_dims; i++) {
+      if (!FastBoundsCheck(tensor.dim_size(i),
+                           std::numeric_limits<int32>::max())) {
+        return errors::InvalidArgument(tensor_name, " dimension ", i,
+                                       " too large");
+      }
+    }
+    // If there is a 5th dimension it is the VECT_C or VECT_I dimension.
+    if (num_dims == 5 && tensor.dim_size(4) != 4) {
+      return errors::InvalidArgument("The last dimension of ", tensor_name,
+                                     " must be of size 4 for qint8.");
+    }
+    return Status::OK();
+  }
 
   void Compute(OpKernelContext* context) override {
-    // Input tensor is one of the following shapes:
-    // [ batch, in_rows, in_cols, in_depth ] (for NHWC data format)
-    // [ batch, in_depth, in_rows, in_cols ] (for NCHW data format)
-    const Tensor& input = context->input(0);
+    // The conv_input tensor is one of the following formats:
+    // NHWC, NCHW, NCHW_VECT_C.
+    const Tensor& conv_input = context->input(0);
+    OP_REQUIRES_OK(context, CheckShape(conv_input, "conv_input"));
 
-    // Input filter is of the following dimensions:
-    // [ filter_rows, filter_cols, in_depth, out_depth ]
+    // The filter tensor is one of the following formats:
+    // HWIO, OIHW, OIHW_VECT_I.
     const Tensor& filter = context->input(1);
+    OP_REQUIRES_OK(context, CheckShape(filter, "filter"));
 
-    // Input bias is a 1-D tensor the size of the last
-    // dimension of Output tensor
+    // Input bias is a 1-D tensor, with size matching output depth.
     const Tensor& bias = context->input(2);
+    OP_REQUIRES_OK(context, CheckShape(bias, "bias"));
 
-    // For 2D convolution, there should be 4 dimensions.
-    OP_REQUIRES(context, input.dims() == 4,
-                errors::InvalidArgument("input must be 4-dimensional",
-                                        input.shape().DebugString()));
-    OP_REQUIRES(context, filter.dims() == 4,
-                errors::InvalidArgument("filter must be 4-dimensional: ",
-                                        filter.shape().DebugString()));
-
-    // Bias should be a 1-D tensor.
- OP_REQUIRES(context, bias.dims() == 1, - errors::InvalidArgument("bias must be 1-dimensional: ", - bias.shape().DebugString())); - - for (int i = 0; i < 4; i++) { - OP_REQUIRES(context, - FastBoundsCheck(filter.dim_size(i), - std::numeric_limits::max()), - errors::InvalidArgument("filter dimension too large")); - OP_REQUIRES( - context, - FastBoundsCheck(input.dim_size(i), std::numeric_limits::max()), - errors::InvalidArgument("input dimension too large")); + // If side_input_scale != 0, then side_input is not ignored and + // has the same type and dimensions as the output. + const Tensor& side_input = context->input(3); + if (side_input_scale_ != 0) { + OP_REQUIRES_OK(context, CheckShape(side_input, "side_input")); } - // The last dimension for input is in_depth. It must be the same as the - // filter's in_depth. - const int64 in_depth = GetTensorDim(input, data_format_, 'C'); - OP_REQUIRES(context, in_depth == filter.dim_size(2), - errors::InvalidArgument( - "input and filter must have the same depth: ", in_depth, - " vs ", filter.dim_size(2))); + // TODO(pauldonnelly): Switch to a more efficient mechanism to access + // dimension indexes and per-dimension attributes. + const int32 filter_rows = GetFilterDim(filter, filter_format_, 'H'); + const int32 filter_cols = GetFilterDim(filter, filter_format_, 'W'); + const int32 output_depth = GetFilterDim(filter, filter_format_, 'O'); - // The last dimension for filter is out_depth. - const int32 out_depth = static_cast(filter.dim_size(3)); + const int32 batch_size = GetTensorDim(conv_input, data_format_, 'N'); + const int32 conv_input_rows = GetTensorDim(conv_input, data_format_, 'H'); + const int32 conv_input_cols = GetTensorDim(conv_input, data_format_, 'W'); - // The second dimension for input is rows/height. - // The first dimension for filter is rows/height. - const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H'); - const int32 input_rows = static_cast(input_rows_raw); - const int32 filter_rows = static_cast(filter.dim_size(0)); - - // The third dimension for input is columns/width. - // The second dimension for filter is columns/width. - const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W'); - const int32 input_cols = static_cast(input_cols_raw); - const int32 filter_cols = static_cast(filter.dim_size(1)); - - // The first dimension for input is batch. - const int64 batch_raw = GetTensorDim(input, data_format_, 'N'); - const int32 batch = static_cast(batch_raw); - - // For now we take the stride from the second and third dimensions only (we - // do not support striding on the batch or depth dimension). 
- const int32 stride_rows = - static_cast(GetTensorDim(strides_, data_format_, 'H')); - const int32 stride_cols = - static_cast(GetTensorDim(strides_, data_format_, 'W')); - const int32 bias_size = static_cast(bias.dim_size(0)); - - int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; - OP_REQUIRES_OK(context, - GetWindowedOutputSize(input_rows, filter_rows, stride_rows, - padding_, &out_rows, &pad_rows)); - OP_REQUIRES_OK(context, - GetWindowedOutputSize(input_cols, filter_cols, stride_cols, - padding_, &out_cols, &pad_cols)); - // Output tensor is of the following dimensions: - // [ in_batch, out_rows, out_cols, out_depth ] - TensorShape out_shape = - ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth); + int64 output_rows = 0, output_cols = 0, pad_rows = 0, pad_cols = 0; + OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_rows, filter_rows, + stride_rows_, padding_type_, + &output_rows, &pad_rows)); + OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_cols, filter_cols, + stride_cols_, padding_type_, + &output_cols, &pad_cols)); + // Initialize the output tensor shape according to data_format_ + TensorShape output_shape = ShapeFromFormat( + data_format_, batch_size, output_rows, output_cols, output_depth); Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - // Bias size should be the same as the size of the channel dimension of - // output. - OP_REQUIRES(context, bias_size == out_depth, - errors::InvalidArgument( - "bias size should equal the channel " - "dimension size of output. bias shape: ", - bias.shape().DebugString() + - ", output shape: " + output->shape().DebugString())); - - VLOG(2) << "FusedConv2DBiasActivation: in_depth = " << in_depth - << ", input_cols = " << input_cols + VLOG(2) << "FusedConv2DBiasActivation: conv_input_cols = " + << conv_input_cols << ", conv_input_rows = " << conv_input_rows << ", filter_cols = " << filter_cols - << ", input_rows = " << input_rows << ", filter_rows = " << filter_rows - << ", stride_rows = " << stride_rows - << ", stride_cols = " << stride_cols - << ", bias_size = " << bias_size << ", out_depth = " << out_depth; + << ", stride_cols = " << stride_cols_ + << ", stride_rows = " << stride_rows_ + << ", output_depth = " << output_depth + << ", output_cols = " << output_cols + << ", output_rows = " << output_rows + << ", output_shape.num_elements = " << output_shape.num_elements(); // If there is nothing to compute, return. 
- if (out_shape.num_elements() == 0) { + if (output_shape.num_elements() == 0) { return; } - launcher_.launch(context, cudnn_use_autotune_, input, filter, stride_rows, - stride_cols, bias, activation_mode_, - BrainPadding2EigenPadding(padding_), data_format_, output); + + launcher_.launch(context, cudnn_use_autotune_, conv_input, + conv_input_scale_, filter, stride_rows_, stride_cols_, + eigen_padding_type_, side_input, side_input_scale_, bias, + activation_mode_, data_format_, filter_format_, output); } private: - std::vector strides_; - Padding padding_; + int32 stride_rows_, stride_cols_; + Padding padding_type_; + Eigen::PaddingType eigen_padding_type_; ActivationMode activation_mode_; TensorFormat data_format_; - LaunchFusedConv2DBiasActivationOp launcher_; + FilterTensorFormat filter_format_; + ScaleType conv_input_scale_; + ScaleType side_input_scale_; + LaunchFusedConv2DBiasActivationOp launcher_; bool cudnn_use_autotune_; TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DBiasActivationOp); @@ -211,67 +236,72 @@ class FusedConv2DBiasActivationOp : public OpKernel { #if GOOGLE_CUDA namespace dnn = ::perftools::gputools::dnn; -dnn::ActivationMode BrainActivationMode2CudnnActivationMode( - ActivationMode activation_mode) { - switch (activation_mode) { - case ActivationMode::SIGMOID: - return dnn::ActivationMode::kSigmoid; - case ActivationMode::RELU: - return dnn::ActivationMode::kRelu; - case ActivationMode::RELUX: - return dnn::ActivationMode::kReluX; - case ActivationMode::RELU6: - return dnn::ActivationMode::kRelu6; - case ActivationMode::TANH: - return dnn::ActivationMode::kTanh; - case ActivationMode::BANDPASS: - return dnn::ActivationMode::kBandPass; - } - // Prevent compiler warning about missing return - return dnn::ActivationMode::kRelu; -} - // A dummy type to group forward convolution autotune results together. struct ConvBiasActivationAutoTuneGroup { static string name() { return "ConvBiasActivation"; } }; -typedef AutoTuneSingleton +typedef AutoTuneSingleton AutoTuneConvBiasActivation; -template -void LaunchFusedConv2DBiasActivationOp::launch( - OpKernelContext* ctx, bool cudnn_use_autotune, const Tensor& input_param, - const Tensor& filter, int32 row_stride, int32 col_stride, - const Tensor& bias, const ActivationMode& activation_mode, - const Eigen::PaddingType& padding, TensorFormat data_format, - Tensor* output) { - using perftools::gputools::dnn::AlgorithmConfig; - using perftools::gputools::dnn::AlgorithmType; - using perftools::gputools::dnn::ProfileResult; - using perftools::gputools::dnn::kDefaultAlgorithm; +// Allocates 'transformed_tensor' and transforms 'nhwc_tensor' into it +// using the specified 'batch_size', 'rows', 'cols', and 'depth' dimensions. +template +Status TransformNHWCToNCHW(OpKernelContext* ctx, const Tensor& nhwc_tensor, + int batch_size, int rows, int cols, int depth, + Tensor* transformed_tensor, const Tensor** result) { + TensorShape nchw_shape = + ShapeFromFormat(FORMAT_NCHW, batch_size, rows, cols, depth); + if (depth > 1) { + TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum::value, nchw_shape, + transformed_tensor)); + functor::NHWCToNCHW()( + ctx->eigen_device(), nhwc_tensor.tensor(), + transformed_tensor->tensor()); + } else { + // If depth <= 1, then just reshape. 
+ CHECK(transformed_tensor->CopyFrom(nhwc_tensor, nchw_shape)); + } + *result = transformed_tensor; + return Status::OK(); +} + +template +void LaunchFusedConv2DBiasActivationOp:: + launch(OpKernelContext* ctx, bool cudnn_use_autotune, + const Tensor& conv_input_param, ScaleType conv_input_scale, + const Tensor& filter_param, int32 row_stride, int32 col_stride, + const Eigen::PaddingType& padding, const Tensor& side_input_param, + ScaleType side_input_scale, const Tensor& bias, + ActivationMode activation_mode, TensorFormat data_format, + FilterTensorFormat filter_format, Tensor* output_param) { auto* stream = ctx->op_device_context()->stream(); OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); - Tensor input = input_param; - - perftools::gputools::dnn::ActivationMode cudnn_activation_mode = - BrainActivationMode2CudnnActivationMode(activation_mode); - // TODO(yangzihao): refactor all the complicated/duplicated code in regular // conv ops to a shared conv utility. - int32 padding_rows = 0; - int32 padding_cols = 0; - const int64 in_batch = GetTensorDim(input, data_format, 'N'); - int64 in_rows = GetTensorDim(input, data_format, 'H'); - int64 in_cols = GetTensorDim(input, data_format, 'W'); - const int64 in_depths = GetTensorDim(input, data_format, 'C'); - const int64 out_batch = GetTensorDim(*output, data_format, 'N'); - const int64 out_rows = GetTensorDim(*output, data_format, 'H'); - const int64 out_cols = GetTensorDim(*output, data_format, 'W'); - const int64 out_depths = GetTensorDim(*output, data_format, 'C'); - const int64 patch_rows = filter.dim_size(0); - const int64 patch_cols = filter.dim_size(1); + + // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here. + constexpr bool is_int8x4 = std::is_same::value; + constexpr int rank = is_int8x4 ? 5 : 4; + constexpr int vect = is_int8x4 ? 4 : 1; + + const int batch_size = GetTensorDim(conv_input_param, data_format, 'N'); + int conv_input_rows = GetTensorDim(conv_input_param, data_format, 'H'); + int conv_input_cols = GetTensorDim(conv_input_param, data_format, 'W'); + + const int conv_input_depth = + GetTensorDim(conv_input_param, data_format, 'C') * vect; + const int output_rows = GetTensorDim(*output_param, data_format, 'H'); + const int output_cols = GetTensorDim(*output_param, data_format, 'W'); + const int output_depth = GetFilterDim(filter_param, filter_format, 'O'); + const int filter_rows = GetFilterDim(filter_param, filter_format, 'H'); + const int filter_cols = GetFilterDim(filter_param, filter_format, 'W'); + int padding_rows = 0; + int padding_cols = 0; + const Tensor* conv_input = &conv_input_param; + + Tensor maybe_padded_conv_input; if (padding == Eigen::PADDING_SAME) { // Total padding on rows and cols is // Pr = (R' - 1) * S + Kr - R @@ -281,114 +311,152 @@ void LaunchFusedConv2DBiasActivationOp::launch( // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means // we pad more on the right and bottom than on the top and left. 
- padding_rows = - std::max(0, (out_rows - 1) * row_stride + patch_rows - in_rows); - padding_cols = - std::max(0, (out_cols - 1) * col_stride + patch_cols - in_cols); - const int rows_parity = padding_rows & 1; - const int cols_parity = padding_cols & 1; - if ((rows_parity | cols_parity) != 0) { + padding_rows = std::max( + 0, (output_rows - 1) * row_stride + filter_rows - conv_input_rows); + padding_cols = std::max( + 0, (output_cols - 1) * col_stride + filter_cols - conv_input_cols); + const int padding_rows_parity = padding_rows & 1; + const int padding_cols_parity = padding_cols & 1; + if ((padding_rows_parity | padding_cols_parity) != 0) { Tensor transformed_input; - int64 new_in_rows = in_rows + rows_parity; - int64 new_in_cols = in_cols + cols_parity; + const int new_conv_input_rows = conv_input_rows + padding_rows_parity; + const int new_conv_input_cols = conv_input_cols + padding_cols_parity; + + using VectT = typename Int8x4ToInt32::type>::type; + auto pad_data_format = is_int8x4 ? FORMAT_NCHW : data_format; + OP_REQUIRES_OK( - ctx, - ctx->allocate_temp(DataTypeToEnum::value, - ShapeFromFormat(data_format, in_batch, new_in_rows, - new_in_cols, in_depths), - &transformed_input)); + ctx, ctx->allocate_temp( + DataTypeToEnum::value, + ShapeFromFormat(data_format, batch_size, new_conv_input_rows, + new_conv_input_cols, conv_input_depth), + &maybe_padded_conv_input)); - functor::PadInput()( - ctx->eigen_device(), To32Bit(input_param.tensor()), - {{0, 0}}, {{rows_parity, cols_parity}}, - To32Bit(transformed_input.tensor()), data_format); + auto conv_input_eigen_tensor = + To32Bit(conv_input_param.reinterpret_last_dimension()); + auto padded_conv_input_eigen_tensor = To32Bit( + maybe_padded_conv_input.reinterpret_last_dimension()); - input = transformed_input; - in_rows = new_in_rows; - in_cols = new_in_cols; + functor::PadInput()( + ctx->eigen_device(), conv_input_eigen_tensor, {{0, 0}}, + {{padding_rows_parity, padding_cols_parity}}, + padded_conv_input_eigen_tensor, pad_data_format); + + conv_input = &maybe_padded_conv_input; + conv_input_rows = new_conv_input_rows; + conv_input_cols = new_conv_input_cols; } } - if (data_format == FORMAT_NHWC) { - // Convert the input tensor from NHWC to NCHW. - TensorShape nchw_shape = - ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths); - if (in_depths > 1) { - Tensor transformed_input; - OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, - nchw_shape, &transformed_input)); - functor::NHWCToNCHW()( - ctx->eigen_device(), - const_cast(input).tensor(), - transformed_input.tensor()); - input = transformed_input; - } else { - // If depth <= 1, then just reshape. - CHECK(input.CopyFrom(input, nchw_shape)); + Tensor maybe_transformed_conv_input, maybe_transformed_side_input; + Tensor maybe_transformed_output; + const Tensor* side_input = &side_input_param; + Tensor* output = output_param; + + // NOTE: Here and elsewhere, checking 'is_int8x4' may look unnecessary + // and inefficient, but it is actually both a time and code size optimization, + // since 'is_int8x4' is a constexpr determined by the template parameter. 
+ if (!is_int8x4 && data_format == FORMAT_NHWC) { + OP_REQUIRES_OK(ctx, (TransformNHWCToNCHW( + ctx, *conv_input, batch_size, conv_input_rows, + conv_input_cols, conv_input_depth, + &maybe_transformed_conv_input, &conv_input))); + if (side_input_scale != 0) { + OP_REQUIRES_OK( + ctx, (TransformNHWCToNCHW( + ctx, side_input_param, batch_size, output_rows, output_cols, + output_depth, &maybe_transformed_side_input, &side_input))); + } + if (output_depth > 1) { + // Allocate a tensor for the NCHW output of the kernel and point output + // to it. Afterwards, we will transform it to NHWC while copying back to + // 'output_param'. + TensorShape nchw_shape = ShapeFromFormat( + FORMAT_NCHW, batch_size, output_rows, output_cols, output_depth); + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum::value, nchw_shape, + &maybe_transformed_output)); + output = &maybe_transformed_output; } } - CHECK(padding_rows >= 0 && padding_cols >= 0) - << "Negative row or col paddings: (" << padding_rows << ", " - << padding_cols << ")"; - perftools::gputools::dnn::BatchDescriptor input_desc; - input_desc.set_count(in_batch) - .set_feature_map_count(in_depths) - .set_height(in_rows) - .set_width(in_cols) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::BatchDescriptor output_desc; - output_desc.set_count(out_batch) - .set_height(out_rows) - .set_width(out_cols) - .set_feature_map_count(out_depths) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::FilterDescriptor filter_desc; - filter_desc.set_input_filter_height(filter.dim_size(0)) - .set_input_filter_width(filter.dim_size(1)) - .set_input_feature_map_count(filter.dim_size(2)) - .set_output_feature_map_count(filter.dim_size(3)); - perftools::gputools::dnn::ConvolutionDescriptor conv_desc; + constexpr auto data_layout = is_int8x4 ? dnn::DataLayout::kBatchDepthYX4 + : dnn::DataLayout::kBatchDepthYX; + constexpr auto filter_layout = is_int8x4 ? 
dnn::FilterLayout::kOutputInputYX4 + : dnn::FilterLayout::kOutputInputYX; + + dnn::BatchDescriptor conv_input_desc; + conv_input_desc.set_count(batch_size) + .set_feature_map_count(conv_input_depth) + .set_height(conv_input_rows) + .set_width(conv_input_cols) + .set_layout(data_layout); + dnn::FilterDescriptor filter_desc; + filter_desc.set_input_filter_height(filter_rows) + .set_input_filter_width(filter_cols) + .set_input_feature_map_count(conv_input_depth) + .set_output_feature_map_count(output_depth) + .set_layout(filter_layout); + dnn::BatchDescriptor side_input_desc; + side_input_desc.set_count(batch_size) + .set_height(output_rows) + .set_width(output_cols) + .set_feature_map_count(output_depth) + .set_layout(data_layout); + dnn::BatchDescriptor bias_desc; + bias_desc.set_count(1) + .set_height(1) + .set_width(1) + .set_feature_map_count(output_depth) + .set_layout(dnn::DataLayout::kBatchDepthYX); + dnn::BatchDescriptor output_desc; + output_desc.set_count(batch_size) + .set_height(output_rows) + .set_width(output_cols) + .set_feature_map_count(output_depth) + .set_layout(data_layout); + dnn::ConvolutionDescriptor conv_desc; conv_desc.set_vertical_filter_stride(row_stride) .set_horizontal_filter_stride(col_stride) .set_zero_padding_height(padding_rows / 2) .set_zero_padding_width(padding_cols / 2); - // Shuffles a filter tensor from: - // [, in, out] - // to: - // [out, in, ] - // TODO(yangzihao): Support a data layout tag for the filter weights, and only - // do the transform if the weights are not already in the correct layout. - Tensor transformed_filter; - OP_REQUIRES_OK(ctx, ctx->allocate_temp( - DataTypeToEnum::value, - TensorShape({filter.dim_size(3), filter.dim_size(2), - filter.dim_size(0), filter.dim_size(1)}), - &transformed_filter)); + Tensor maybe_transformed_filter; + const Tensor* filter; + if (is_int8x4) { + // We have already checked filter is OIHW_VECT_I in the constructor. 
+ filter = &filter_param; + } else if (filter_format == FORMAT_HWIO) { + // Shuffle filter tensor from HWIO to OIHW: + OP_REQUIRES_OK(ctx, ctx->allocate_temp( + DataTypeToEnum::value, + ShapeFromFilterFormat( + FORMAT_OIHW, filter_param.shape(), FORMAT_HWIO), + &maybe_transformed_filter)); + functor::TransformFilter()( + ctx->eigen_device(), To32Bit(filter_param.tensor()), + To32Bit(maybe_transformed_filter.tensor())); + filter = &maybe_transformed_filter; + } - functor::TransformFilter()( - ctx->eigen_device(), To32Bit(filter.tensor()), - To32Bit(transformed_filter.tensor())); - - Tensor transformed_output; - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(DataTypeToEnum::value, - ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows, - out_cols, out_depths), - &transformed_output)); - - auto input_ptr = AsDeviceMemory(input.template flat().data(), - input.template flat().size()); + auto conv_input_ptr = + AsDeviceMemory(reinterpret_cast::type*>( + conv_input->template flat().data()), + conv_input->template flat().size()); auto filter_ptr = - AsDeviceMemory(transformed_filter.template flat().data(), - transformed_filter.template flat().size()); + AsDeviceMemory(reinterpret_cast::type*>( + filter->template flat().data()), + filter->template flat().size()); + auto side_input_ptr = + AsDeviceMemory(reinterpret_cast::type*>( + side_input->template flat().data()), + side_input->template flat().size()); auto output_ptr = - AsDeviceMemory(transformed_output.template flat().data(), - transformed_output.template flat().size()); - - auto bias_ptr = AsDeviceMemory(bias.template flat().data(), - bias.template flat().size()); + AsDeviceMemory(reinterpret_cast::type*>( + output->template flat().data()), + output->template flat().size()); + auto bias_ptr = AsDeviceMemory(bias.template flat().data(), + bias.template flat().size()); static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( // default value is in bytes despite the name of the environment variable @@ -396,38 +464,42 @@ void LaunchFusedConv2DBiasActivationOp::launch( ); int device_id = stream->parent()->device_ordinal(); - DataType dtype = input.dtype(); - ConvParameters conv_parameters = { - in_batch, - in_depths, - {{in_rows, in_cols}}, - out_depths, - {{patch_rows, patch_cols}}, + FusedConvParameters fused_conv_parameters = { + batch_size, + conv_input_depth, + {{conv_input_rows, conv_input_cols}}, + output_depth, + {{filter_rows, filter_cols}}, {{row_stride, col_stride}}, {{padding_rows, padding_cols}}, - dtype, + conv_input->dtype(), device_id, + (side_input_scale != 0), + activation_mode, }; - AlgorithmConfig algorithm_config; + dnn::AlgorithmConfig algorithm_config; if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find( - conv_parameters, &algorithm_config)) { - std::vector algorithms; + fused_conv_parameters, &algorithm_config)) { + std::vector algorithms; CHECK(stream->parent()->GetConvolveAlgorithms( - conv_parameters.ShouldIncludeWinogradNonfusedAlgo(), &algorithms)); - ProfileResult best_result; - ProfileResult best_result_no_scratch; + fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo(), + &algorithms)); + dnn::ProfileResult best_result; + dnn::ProfileResult best_result_no_scratch; for (auto profile_algorithm : algorithms) { // TODO(zhengxq): profile each algorithm multiple times to better // accuracy. 
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); - ProfileResult profile_result; + dnn::ProfileResult profile_result; bool cudnn_launch_status = stream - ->ThenConvolveWithAlgorithm( - input_desc, input_ptr, filter_desc, filter_ptr, conv_desc, - bias_ptr, cudnn_activation_mode, output_desc, &output_ptr, - &scratch_allocator, AlgorithmConfig(profile_algorithm), + ->ThenFusedConvolveWithAlgorithm( + conv_input_desc, conv_input_ptr, conv_input_scale, + filter_desc, filter_ptr, conv_desc, side_input_ptr, + side_input_scale, bias_desc, bias_ptr, + dnn::ActivationMode::kRelu, output_desc, &output_ptr, + &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm), &profile_result) .ok(); if (cudnn_launch_status) { @@ -454,42 +526,68 @@ void LaunchFusedConv2DBiasActivationOp::launch( algorithm_config.set_algorithm_no_scratch( best_result_no_scratch.algorithm()); } - AutoTuneConvBiasActivation::GetInstance()->Insert(conv_parameters, + AutoTuneConvBiasActivation::GetInstance()->Insert(fused_conv_parameters, algorithm_config); } CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); bool cudnn_launch_status = stream - ->ThenConvolveWithAlgorithm( - input_desc, input_ptr, filter_desc, filter_ptr, conv_desc, - bias_ptr, cudnn_activation_mode, output_desc, &output_ptr, - &scratch_allocator, algorithm_config, + ->ThenFusedConvolveWithAlgorithm( + conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc, + filter_ptr, conv_desc, side_input_ptr, side_input_scale, + bias_desc, bias_ptr, dnn::ActivationMode::kRelu, output_desc, + &output_ptr, &scratch_allocator, algorithm_config, /*output_profile_result=*/nullptr) .ok(); if (!cudnn_launch_status) { - ctx->SetStatus(errors::Internal( - "cuDNN launch failure : input shape(", input.shape().DebugString(), - ") filter shape(", filter.shape().DebugString(), ")")); + ctx->SetStatus(errors::Internal("cuDNN launch failure : conv_input shape(", + conv_input->shape().DebugString(), + ") filter shape(", + filter->shape().DebugString(), ")")); } - // Convert the output tensor back from NCHW to NHWC. - if (data_format == FORMAT_NHWC) { + // Convert the output tensor back from NCHW to NHWC if necessary. + if (!is_int8x4 && (data_format == FORMAT_NHWC) && (output_depth > 1)) { functor::NCHWToNHWC()( ctx->eigen_device(), - const_cast(transformed_output).tensor(), - output->tensor()); - } else { - *output = transformed_output; + const_cast(output)->tensor(), + output_param->tensor()); } } +// Forward declarations of the functor specializations for GPU used above. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void PadInput::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + const std::array& padding_left, \ + const std::array& padding_right, \ + typename TTypes::Tensor out, TensorFormat data_format); \ + extern template struct PadInput; + +DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(int32); +#undef DECLARE_GPU_SPEC +} // namespace functor + // Registration of the GPU implementations. 
-REGISTER_KERNEL_BUILDER(Name("FusedConv2DBiasActivation") - .Device(DEVICE_GPU) - .TypeConstraint("T"), - FusedConv2DBiasActivationOp); + +REGISTER_KERNEL_BUILDER( + Name("FusedConv2DBiasActivation") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .TypeConstraint("Tbias"), + FusedConv2DBiasActivationOp); + +REGISTER_KERNEL_BUILDER( + Name("FusedConv2DBiasActivation") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .TypeConstraint("Tbias"), + FusedConv2DBiasActivationOp); #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h index d71b26cf1db..7534f5797c4 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h @@ -24,7 +24,7 @@ limitations under the License. #if GOOGLE_CUDA #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -33,27 +33,30 @@ namespace tensorflow { // Forward declaration. class OpKernelContext; -template +template class LaunchFusedConv2DBiasActivationOp { public: void launch(OpKernelContext* ctx, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int row_stride, - int col_stride, const Tensor& bias, - const ActivationMode& activation_mode, - const Eigen::PaddingType& padding, TensorFormat data_format, - Tensor* output); + const Tensor& conv_input, ScaleType conv_input_scale, + const Tensor& filter, int32 row_stride, int32 col_stride, + const Eigen::PaddingType& padding, const Tensor& side_input, + ScaleType side_input_scale, const Tensor& bias, + ActivationMode activation_mode, TensorFormat data_format, + FilterTensorFormat filter_format, Tensor* output); }; #ifdef GOOGLE_CUDA -template -class LaunchFusedConv2DBiasActivationOp { +template +class LaunchFusedConv2DBiasActivationOp { public: void launch(OpKernelContext* ctx, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int32 row_stride, - int32 col_stride, const Tensor& bias, - const ActivationMode& activation_mode, - const Eigen::PaddingType& padding, TensorFormat data_format, - Tensor* output); + const Tensor& conv_input, ScaleType conv_input_scale, + const Tensor& filter, int32 row_stride, int32 col_stride, + const Eigen::PaddingType& padding, const Tensor& side_input, + ScaleType side_input_scale, const Tensor& bias, + ActivationMode activation_mode, TensorFormat data_format, + FilterTensorFormat filter_format, Tensor* output); }; #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h new file mode 100644 index 00000000000..dc43af11580 --- /dev/null +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ + +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/util/activation_mode.h" + +// TODO(pauldonnelly): Merge this file into core/kernels/conv_ops_gpu.h. + +namespace tensorflow { + +// Add additional parameters specific to fused convolutions. +class FusedConvParameters : public ConvParameters { + public: + FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, + int64 out_depths, const SpatialArray& filter, + const SpatialArray& stride, const SpatialArray& padding, + DataType dtype, int device_id, bool has_side_input, + ActivationMode activation_mode) + : ConvParameters(batch, in_depths, in, out_depths, filter, stride, + padding, dtype, device_id), + activation_mode_(activation_mode), + has_side_input_(has_side_input) { + hash_code_ = Hash64Combine(hash_code_, has_side_input); + hash_code_ = Hash64Combine(hash_code_, activation_mode); + } + + bool operator==(const FusedConvParameters& other) const { + return this->get_data_as_tuple() == other.get_data_as_tuple(); + } + + bool operator!=(const FusedConvParameters& other) const { + return !(*this == other); + } + + string ToString() const { + return strings::StrCat(ConvParameters::ToString(), ", ", has_side_input_, + ", ", activation_mode_, ", "); + } + + private: + using ParameterDataType = + std::tuple; + + ParameterDataType get_data_as_tuple() const { + return std::make_tuple(ConvParameters::get_data_as_tuple(), has_side_input_, + activation_mode_); + } + + ActivationMode activation_mode_; + bool has_side_input_; +}; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index 6134c5c699d..48f058b4c53 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -33,40 +33,73 @@ string GetAllActivationModeAttrString() { return "activation_mode: {'Relu'}"; } } // namespace // -------------------------------------------------------------------------- + +// TODO(pauldonnelly): Add support for double inputs and scales to this Op, +// (currently Attr does not support double). 
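The point of the `FusedConvParameters` class above is that the autotune cache key must distinguish launches that differ only in side input or activation mode, so both new fields are folded into the hash and the comparison tuple. A rough Python analogue of the key (field names hypothetical):

    from collections import namedtuple

    # Tuples hash and compare field by field, which is what
    # get_data_as_tuple() achieves on the C++ side.
    FusedConvKey = namedtuple("FusedConvKey", [
        "batch", "in_depth", "in_hw", "out_depth", "filter_hw", "stride",
        "padding", "dtype", "device_id", "has_side_input", "activation_mode"
    ])

    key = FusedConvKey(2, 8, (8, 8), 16, (3, 3), (2, 2), (0, 0),
                       "qint8", 0, True, "Relu")
    autotune_results = {key: "algorithm_7"}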
+
 REGISTER_OP("FusedConv2DBiasActivation")
-    .Input("input: T")
+    .Input("conv_input: T")
     .Input("filter: T")
-    .Input("bias: T")
+    .Input("bias: Tbias")
+    .Input("side_input: T")
     .Output("output: T")
-    .Attr("T: {float}")
+    .Attr("T: {float, half, qint8}")
+    .Attr("Tbias: {float, half}")
+    .Attr("conv_input_scale: float = 1.0")
+    .Attr("side_input_scale: float = 0.0")
     .Attr("strides: list(int)")
     .Attr(GetPaddingAttrString())
-    .Attr(GetConvnetDataFormatAttrString())
-    .Attr(GetAllActivationModeAttrString())
+    .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
+    .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'")
+    .Attr("activation_mode: {'Relu'} = 'Relu'")
     .SetShapeFn(shape_inference::FusedConvBiasActivationShape)
     .Doc(R"doc(
-    Computes a fused 2-D convolution, adds bias, and applies an activation function
-    on the output given 4-D `input`, 4-D `filter`, 1-D `bias` tensors and an activation mode.
+    Computes a fused kernel which implements 2-D convolution, adds side input,
+    with separate scaling on convolution and side inputs, then adds bias and
+    applies the ReLU activation function to the result. Supports both float and
+    qint8 data formats. In the case of qint8, the output is clipped to [0..127].

-    input: A 4-D tensor. The dimension order is interpreted according to the value
-    of `data_format`, see below for details.
-    filter: A 4-D tensor of shape
-    `[filter_height, filter_width, in_channels, out_channels]`
-    bias: 1-D with size of the `out_channels` dimension in filter.
-    output: A 4-D tensor. The dimension order is determined by the value of
-    `data_format`, see below for details.
-    T: The data type for the elements of input, filter, bias, and output Tensors.
+    conv_input: A tensor with format as specified by `data_format` (see below).
+    filter: A tensor with format depending on `data_format` as follows:
+        "NHWC", "NCHW":
+            `float [ filter_height, filter_width, in_channels, out_channels ]`
+        "NCHW_VECT_C":
+            `qint8 [ out_channels, in_channels, filter_height, filter_width ]`
+    bias: 1-D float tensor with size matching the `out_channels` dimension of
+        `filter`.
+        Note: this tensor is still float, even if other inputs are qint8.
+    side_input: A tensor with format as specified by `data_format` (see below).
+        This tensor will be ignored and can be [] if side_input_scale == 0.
+        Otherwise, the size of each dimension must match the `output` tensor.
+    output: A tensor with format as specified by `data_format` (see below).
+        The dimension sizes are determined automatically based on other inputs
+        and attributes.
+    T: The element data type of `conv_input`, `side_input` and `output` tensors.
+        Note: must match with the `data_format`.
+    Tbias: The element data type of `bias`.
+    conv_input_scale: scalar float value to be multiplied by `conv_input`
+        (conceptually; in reality it is applied after convolution).
+    side_input_scale: scalar float value to be multiplied by `side_input`.
     strides: 1-D tensor of length 4.  The stride of the sliding window for each
         dimension of `input`. The dimension order is determined by the value of
         `data_format`, see below for details.
+        Note: the stride for batch and channel dimensions must be 1.
     padding: The type of padding algorithm to use.
-    data_format: Specify the data format of the input and output data. With the
-    default format "NHWC", the data is stored in the order of:
-    [batch, height, width, channels].
-    Alternatively, the format could be "NCHW", the data storage order of:
-    [batch, channels, height, width].
-    activation_mode: Specify the activation function to apply to the output tensor
-    of bias add. Currently only supports "Relu".
+    data_format: A string specifying the data format of `conv_input`,
+        `side_input` and `output` tensors with the following options:
+        "NHWC": `float [ batch, height, width, channels ]`
+        "NCHW": `float [ batch, channels, height, width ]`
+        "NCHW_VECT_C":
+            `qint8 [ batch, channels / 4, height, width, channels % 4 ]`
+        Note: for "NCHW_VECT_C", `channels` must be a multiple of 4.
+    filter_format: A string specifying the data format of `filter`,
+        "HWIO": `float [ kernel_height, kernel_width, input_channels,
+                         output_channels ]`
+        "OIHW_VECT_I":
+            `qint8 [ output_channels, input_channels / 4,
+                     kernel_height, kernel_width, input_channels % 4 ]`
+    activation_mode: The activation applied to the output.
+        Currently must be "Relu".
 )doc");

 }  // namespace tensorflow
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py
index 41f986dd07c..8f3f31bad0d 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py
@@ -26,62 +26,83 @@ _fused_conv2d_bias_activation_op_so = loader.load_op_library(
     resource_loader.get_path_to_datafile("_fused_conv2d_bias_activation_op.so"))


-def fused_conv2d_bias_activation(input_tensor,
-                                 filter_tensor,
+# pylint: disable=redefined-builtin
+def fused_conv2d_bias_activation(conv_input,
+                                 filter,
                                  bias,
-                                 strides,
-                                 padding,
-                                 activation_mode,
+                                 strides=None,
+                                 padding=None,
+                                 conv_input_scale=1.0,
+                                 side_input_scale=0.0,
+                                 side_input=None,
+                                 activation_mode="Relu",
                                  data_format=None,
+                                 filter_format=None,
                                  name=None):
-  """Computes a fused 2-D convolution, adds bias, and applies relu.
+  """Fused 2D conv, bias and activation with optional side input.

-  input_tensor: A 4-D tensor. The dimension order is interpreted
-  according to the value of `data_format`, see below for details.
-  filter_tensor: A 4-D tensor of shape
-  `[filter_height, filter_width, in_channels, out_channels]`
-  bias: 1-D with size of the `out_channels` dimension in filter.
-  output: A 4-D tensor. The dimension order is determined by the value of
-  `data_format`, see below for details.
-  T: The data type for the elements of input, filter, bias, and output
-  Tensors.
-  strides: 1-D tensor of length 4. The stride of the sliding window for
-  each dimension of `input`. The dimension order is determined by the value
-  of `data_format`, see below for details.
-  padding: The type of padding algorithm to use.
-  data_format: Specify the data format of the input and output data. With
-  the default format "NHWC", the data is stored in the order of:
-  [batch, height, width, channels].
-  Alternatively, the format could be "NCHW", the data storage order of:
-  [batch, channels, height, width].
-  activation_mode: Specify the activation function to apply to the output
-  tensor of bias add. Currently only supports "Relu".
+  Computes a fused 2-D convolution scaled by conv_input_scale,
+  adds an optional side input scaled by side_input_scale, adds biases,
+  and applies ReLU. As an equation:
+  output = ReLU(conv_input_scale * Conv(conv_input, filter) +
+                side_input_scale * side_input + bias)
+  Note: In int8 mode, the ReLU will clip the output to the range [0..127].

   Args:
-    input_tensor: A `Tensor`. Must be one of the following types: `float32`.
-    filter_tensor: A `Tensor`.
-      Must have the same type as `input`.
-    bias: A `Tensor`. Must have the same type as `input`.
-    strides: A list of `ints`.
+    conv_input: A `Tensor` of the format specified by `data_format`.
+    filter: A `Tensor` whose format depends on `data_format`:
+        if `data_format` is "NCHW_VECT_C", filter should be "OIHW_VECT_I";
+        otherwise, it should be "HWIO" format.
+    bias: A 1-D `Tensor` of type `float32`, with size equal to the
+        number of output channels.
+    strides: A list of 4 `ints` specifying convolution strides.
+        If `data_format` is "NCHW" or "NCHW_VECT_C", the order should be NCHW.
+        If `data_format` is "NHWC", the order should be NHWC.
     padding: A `string` from: `"SAME", "VALID"`.
-    activation_mode: A `string` from: `"Sigmoid", "Relu", "Relu6", "ReluX",
-      "Tanh", "BandPass"`.
-    data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
-      `"NHWC"`.
+    conv_input_scale: A scalar `float32` that will be multiplied by conv_input.
+        This is optional and defaults to 1. However, it should be set to
+        specify the quantization scale when `data_format` is "NCHW_VECT_C".
+    side_input_scale: A scalar `float32` that will be multiplied by side_input.
+        This is optional and defaults to 0.
+    side_input: A `Tensor` of the format specified by `data_format`.
+        This is useful for implementing ResNet blocks.
+    activation_mode: (optional) currently must be the default "Relu".
+        Note that in qint8 mode, it also clips to 127, so acts like ReluX.
+    data_format: Specifies the data format.
+        Possible values are:
+        "NHWC" float [batch, height, width, channels]
+        "NCHW" float [batch, channels, height, width]
+        "NCHW_VECT_C" qint8 [batch, channels / 4, height, width, channels % 4]
+        Defaults to `"NHWC"`.
+        Performance is worst for `"NHWC"` and best for `"NCHW_VECT_C"`.
+    filter_format: Specifies the filter format.
+        Possible values are:
+        "HWIO" float [kernel_height, kernel_width, input_channels,
+                      output_channels ]
+        "OIHW" float [output_channels, input_channels, kernel_height,
+                      kernel_width ]
+        "OIHW_VECT_I" qint8 [ output_channels, input_channels / 4,
+                              kernel_height, kernel_width, input_channels % 4 ]
+        Defaults to `"HWIO"`.
     name: A name for the operation (optional).

   Returns:
-    A `Tensor`. Has the same type as `input`.
+    A `Tensor` of the format specified by `data_format`.
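To make the calling convention concrete, a minimal float-mode invocation might look as follows (a sketch only; the shapes are invented, and it assumes a TF 1.x build with the contrib op library compiled):

    import tensorflow as tf
    from tensorflow.contrib.fused_conv.python.ops import (
        fused_conv2d_bias_activation_op)

    conv_input = tf.random_uniform([2, 8, 8, 3])  # NHWC
    kernel = tf.random_uniform([3, 3, 3, 16])     # HWIO
    bias = tf.zeros([16])

    # ReLU(Conv(conv_input, kernel) + bias); side input omitted (scale 0).
    output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
        conv_input, kernel, bias, strides=[1, 1, 1, 1], padding="SAME")

The int8 path instead passes qint8 tensors in NCHW_VECT_C/OIHW_VECT_I layouts together with explicit scales, as the tests below demonstrate.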
""" + if strides is None: + strides = [1, 1, 1, 1] + if side_input is None: + side_input = [] return gen_fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - input=input_tensor, - filter=filter_tensor, - bias=bias, - strides=strides, + conv_input, + filter, + bias, padding=padding, + strides=strides, + conv_input_scale=conv_input_scale, + side_input_scale=side_input_scale, + side_input=side_input, activation_mode=activation_mode, data_format=data_format, + filter_format=filter_format, name=name) diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 5d6a2fa3b83..3b8f7d6ed76 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -19,13 +19,16 @@ from __future__ import division from __future__ import print_function import numpy as np + from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -484,7 +487,8 @@ class FusedConv2DBiasActivationTest(test.TestCase): with self.test_session() as sess: # Illegal strides. with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "strides in the batch and depth"): + "Convolutional strides are not supported in " + "the batch or depth dimensions."): sess.run( fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( array_ops.placeholder(dtypes.float32), @@ -494,7 +498,8 @@ class FusedConv2DBiasActivationTest(test.TestCase): padding="SAME", activation_mode="Relu")) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "strides in the batch and depth"): + "Convolutional strides are not supported in " + "the batch or depth dimensions."): sess.run( fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( array_ops.placeholder(dtypes.float32), @@ -552,6 +557,286 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding, return Test +def CalculateCovolvedOutputDim(input_dim, filter_dim, stride, padding_type): + """Calculates the size of an output dimension of a strided convolution. + + Given the sizes of the corresponding dimension of the input and filter shapes, + and the stride and padding_types, calculates the size of the output dimension. + This function can be called separately for each input dimension. + + Args: + input_dim: An `int` specifying the size of the input dimension. + filter_dim: An `int` specifying the size of the filter dimension. + stride: An `int` specifying the step size of the convolution along the + input dimension. + padding_type: either 'VALID' or 'SAME'. + + Returns: + The size of the output dimension. 
+ """ + if padding_type == "VALID": + return (input_dim - filter_dim + stride) // stride + else: # padding_type == 'SAME' + return (input_dim + stride - 1) // stride + + +def NchwVectCToNchw(in_tensor): + # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W] + t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3]) + n = in_tensor.shape.dims[0].value + c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value + h = in_tensor.shape.dims[2].value + w = in_tensor.shape.dims[3].value + return array_ops.reshape(t, [n, c, h, w]) + + +def OihwVectIToHwio(in_tensor): + # [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W] + t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0]) + o = in_tensor.shape.dims[0].value + i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value + h = in_tensor.shape.dims[2].value + w = in_tensor.shape.dims[3].value + return array_ops.reshape(t, [h, w, i, o]) + + +def NchwToNchwVectC(in_tensor): + n, c, h, w = in_tensor.shape.as_list() + assert c % 4 == 0 + t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w]) + return array_ops.transpose(t, [0, 1, 3, 4, 2]) + + +def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, + padding, strides, side_input_scale, + side_input, biases): + """Simulates the int8 fused 2-D convolution op using separate float ops. + + The arguments and return values have the same format, meanings and + restrictions as the actual op. + Args: + conv_input_scale: A scalar 'float'. + conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout. + kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout. + padding: A `string` from: `"SAME", "VALID"`. + strides: A list of `ints`. + side_input_scale: A scalar 'float'. + side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout. + biases: A `Tensor` of type `float32` in NCHW layout. + Returns: + A `Tensor` of type `qint8` in NCHW_VECT_C layout. 
+ """ + conv_result = nn_ops.conv2d( + NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)), + OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)), + strides=strides, + padding=padding, + data_format="NCHW") * conv_input_scale + + conv_and_side_inputs = conv_result + side_input_scale * NchwVectCToNchw( + gen_array_ops.dequantize(side_input, -128, 127)) + + logit = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW") + + result, _, _ = gen_array_ops.quantize_v2( + NchwToNchwVectC(nn_ops.relu(logit)), -128, 127, dtypes.qint8) + return result + + +class FusedConvInt8Tests(test.TestCase): + _test_params = [ + { + "batch_size": 2, + "input_channels": 8, + "output_channels": 16, + "input_height": 8, + "input_width": 8, + "filter_height": 3, + "filter_width": 3, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.0, + "bias_scale": 1, + "padding_type": "VALID" + }, + { + "batch_size": 2, + "input_channels": 8, + "output_channels": 16, + "input_height": 8, + "input_width": 8, + "filter_height": 3, + "filter_width": 3, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.0, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 2, + "input_channels": 8, + "output_channels": 16, + "input_height": 8, + "input_width": 8, + "filter_height": 3, + "filter_width": 3, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "VALID" + }, + { + "batch_size": 2, + "input_channels": 16, + "output_channels": 16, + "input_height": 9, + "input_width": 9, + "filter_height": 3, + "filter_width": 3, + "vertical_stride": 1, + "horizontal_stride": 1, + "conv_input_scale": 0.001, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 3, + "input_channels": 8, + "output_channels": 8, + "input_height": 9, + "input_width": 9, + "filter_height": 5, + "filter_width": 5, + "vertical_stride": 1, + "horizontal_stride": 1, + "conv_input_scale": 0.001, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 3, + "input_channels": 8, + "output_channels": 8, + "input_height": 9, + "input_width": 9, + "filter_height": 7, + "filter_width": 1, + "vertical_stride": 2, + "horizontal_stride": 1, + "conv_input_scale": 0.002, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 3, + "input_channels": 8, + "output_channels": 8, + "input_height": 9, + "input_width": 9, + "filter_height": 1, + "filter_width": 7, + "vertical_stride": 1, + "horizontal_stride": 1, + "conv_input_scale": 0.002, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "SAME" + }, + ] + + def runTest(self, test_param): + batch_size = test_param["batch_size"] + input_channels = test_param["input_channels"] + output_channels = test_param["output_channels"] + input_height = test_param["input_height"] + input_width = test_param["input_width"] + filter_height = test_param["filter_height"] + filter_width = test_param["filter_width"] + vertical_stride = test_param["vertical_stride"] + horizontal_stride = test_param["horizontal_stride"] + conv_input_scale = test_param["conv_input_scale"] + side_input_scale = test_param["side_input_scale"] + bias_scale = test_param["bias_scale"] + padding_type = test_param["padding_type"] + + conv_input, _, _ = gen_array_ops.quantize_v2( + random_ops.random_uniform( + [batch_size, 
input_channels // 4, input_height, input_width, 4], + minval=-0.0, + maxval=1.0, + dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) + + kernel, _, _ = gen_array_ops.quantize_v2( + random_ops.random_uniform( + [ + output_channels, input_channels // 4, filter_height, + filter_width, 4 + ], + minval=-1.0, + maxval=1.0, + dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) + + output_height = CalculateCovolvedOutputDim(input_height, filter_height, + vertical_stride, padding_type) + output_width = CalculateCovolvedOutputDim(input_width, filter_width, + horizontal_stride, padding_type) + print("output_height=", output_height, ", output_width=", output_width) + + side_input, _, _ = gen_array_ops.quantize_v2( + random_ops.random_uniform( + [batch_size, output_channels // 4, output_height, output_width, 4], + minval=0.0, + maxval=1.0, + dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) + + biases = random_ops.random_uniform( + [output_channels], + minval=-10 * bias_scale, + maxval=20 * bias_scale, + dtype=dtypes.float32) + + strides = [1, 1, vertical_stride, horizontal_stride] + + actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + conv_input, + kernel, + biases, + strides=strides, + padding=padding_type, + conv_input_scale=conv_input_scale, + side_input_scale=side_input_scale, + side_input=side_input, + data_format="NCHW_VECT_C", + filter_format="OIHW_VECT_I") + + expected = SimulateFusedConv2dBiasActivationInt8( + conv_input_scale, conv_input, kernel, padding_type, strides, + side_input_scale, side_input, biases) + + with self.test_session(use_gpu=True) as sess: + actual_y, expected_y = sess.run([actual, expected]) + print("actual_y = ", actual_y) + print("expected_y = ", expected_y) + self.assertTrue(np.array_equal(actual_y, expected_y)) + + def testFusedConvInt8(self): + if not test.is_gpu_available( + cuda_only=True, min_cuda_compute_capability=(6, 1)): + tf_logging.info("int8 test skipped because not run with --config=cuda or " + "no GPUs with compute capability >= 6.1 are available.") + return + for test_param in self._test_params: + self.runTest(test_param) + + if __name__ == "__main__": for index, (input_size_, filter_size_, output_size_, stride_, padding_) in enumerate(GetShrunkInceptionShapes()): diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD index 26f0e415180..7e0019ce4ad 100644 --- a/tensorflow/contrib/keras/BUILD +++ b/tensorflow/contrib/keras/BUILD @@ -1,5 +1,6 @@ # Description: # Contains the Keras API (internal TensorFlow version). +# Note that tf.contrib.keras has been deprecated in favor of tf.keras. 
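As flagged in this comment (and in the README change below), the contrib module now simply forwards to the Keras implementation living in core TensorFlow; for users the migration is an import change, roughly as follows (assuming a release in which `tf.keras` is exported):

    import tensorflow as tf

    # Deprecated contrib path:
    model = tf.contrib.keras.models.Sequential()

    # Replacement in core TensorFlow:
    model = tf.keras.models.Sequential()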
licenses(["notice"]) # Apache 2.0 @@ -7,9 +8,6 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow:tensorflow.bzl", "py_test") - py_library( name = "keras", srcs = [ @@ -48,641 +46,10 @@ py_library( "api/keras/utils/__init__.py", "api/keras/wrappers/__init__.py", "api/keras/wrappers/scikit_learn/__init__.py", - "python/keras/__init__.py", - "python/keras/activations.py", - "python/keras/applications/__init__.py", - "python/keras/applications/imagenet_utils.py", - "python/keras/applications/inception_v3.py", - "python/keras/applications/mobilenet.py", - "python/keras/applications/resnet50.py", - "python/keras/applications/vgg16.py", - "python/keras/applications/vgg19.py", - "python/keras/applications/xception.py", - "python/keras/backend.py", - "python/keras/callbacks.py", - "python/keras/constraints.py", - "python/keras/datasets/__init__.py", - "python/keras/datasets/boston_housing.py", - "python/keras/datasets/cifar.py", - "python/keras/datasets/cifar10.py", - "python/keras/datasets/cifar100.py", - "python/keras/datasets/imdb.py", - "python/keras/datasets/mnist.py", - "python/keras/datasets/reuters.py", - "python/keras/engine/__init__.py", - "python/keras/engine/topology.py", - "python/keras/engine/training.py", - "python/keras/initializers.py", - "python/keras/layers/__init__.py", - "python/keras/layers/advanced_activations.py", - "python/keras/layers/convolutional.py", - "python/keras/layers/convolutional_recurrent.py", - "python/keras/layers/core.py", - "python/keras/layers/embeddings.py", - "python/keras/layers/local.py", - "python/keras/layers/merge.py", - "python/keras/layers/noise.py", - "python/keras/layers/normalization.py", - "python/keras/layers/pooling.py", - "python/keras/layers/recurrent.py", - "python/keras/layers/serialization.py", - "python/keras/layers/wrappers.py", - "python/keras/losses.py", - "python/keras/metrics.py", - "python/keras/models.py", - "python/keras/optimizers.py", - "python/keras/preprocessing/__init__.py", - "python/keras/preprocessing/image.py", - "python/keras/preprocessing/sequence.py", - "python/keras/preprocessing/text.py", - "python/keras/regularizers.py", - "python/keras/testing_utils.py", - "python/keras/utils/__init__.py", - "python/keras/utils/conv_utils.py", - "python/keras/utils/data_utils.py", - "python/keras/utils/generic_utils.py", - "python/keras/utils/io_utils.py", - "python/keras/utils/layer_utils.py", - "python/keras/utils/np_utils.py", - "python/keras/utils/vis_utils.py", - "python/keras/wrappers/__init__.py", - "python/keras/wrappers/scikit_learn.py", ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/tensorboard:projector", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client", - "//tensorflow/python:clip_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:ctc_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework", - "//tensorflow/python:framework_ops", - "//tensorflow/python:functional_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:image_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:layers_base", - "//tensorflow/python:logging_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", - "//tensorflow/python:platform", - "//tensorflow/python:random_ops", - 
"//tensorflow/python:sparse_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:tensor_array_grad", - "//tensorflow/python:tensor_array_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "@six_archive//:six", - ], -) - -py_test( - name = "integration_test", - size = "medium", - srcs = ["python/keras/integration_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//third_party/py/numpy", - ], -) - -py_test( - name = "activations_test", - size = "small", - srcs = ["python/keras/activations_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "constraints_test", - size = "small", - srcs = ["python/keras/constraints_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "initializers_test", - size = "small", - srcs = ["python/keras/initializers_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//tensorflow/python:init_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "regularizers_test", - size = "small", - srcs = ["python/keras/regularizers_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "optimizers_test", - size = "medium", - srcs = ["python/keras/optimizers_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training", - "//third_party/py/numpy", - ], -) - -py_test( - name = "losses_test", - size = "small", - srcs = ["python/keras/losses_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "metrics_test", - size = "small", - srcs = ["python/keras/metrics_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "inception_v3_test", - size = "medium", - srcs = ["python/keras/applications/inception_v3_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "mobilenet_test", - size = "medium", - srcs = ["python/keras/applications/mobilenet_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "resnet50_test", - size = "small", - srcs = ["python/keras/applications/resnet50_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "vgg16_test", - size = "small", - srcs = ["python/keras/applications/vgg16_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "vgg19_test", - size = "small", - srcs = ["python/keras/applications/vgg19_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( 
- name = "xception_test", - size = "medium", - srcs = ["python/keras/applications/xception_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "advanced_activations_test", - size = "small", - srcs = ["python/keras/layers/advanced_activations_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "convolutional_recurrent_test", - size = "medium", - srcs = ["python/keras/layers/convolutional_recurrent_test.py"], - shard_count = 2, - srcs_version = "PY2AND3", - tags = ["noasan"], # times out b/63678675 - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "convolutional_test", - size = "medium", - srcs = ["python/keras/layers/convolutional_test.py"], - srcs_version = "PY2AND3", - tags = [ - "manual", - "noasan", # times out b/63678675 - "notsan", - ], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "pooling_test", - size = "small", - srcs = ["python/keras/layers/pooling_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "core_test", - size = "small", - srcs = ["python/keras/layers/core_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "embeddings_test", - size = "small", - srcs = ["python/keras/layers/embeddings_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "local_test", - size = "medium", - srcs = ["python/keras/layers/local_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "merge_test", - size = "small", - srcs = ["python/keras/layers/merge_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "noise_test", - size = "small", - srcs = ["python/keras/layers/noise_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "normalization_test", - size = "small", - srcs = ["python/keras/layers/normalization_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "simplernn_test", - size = "medium", - srcs = ["python/keras/layers/simplernn_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "gru_test", - size = "medium", - srcs = ["python/keras/layers/gru_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], # http://b/62136390 - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "lstm_test", - size = "medium", - srcs = ["python/keras/layers/lstm_test.py"], - srcs_version = "PY2AND3", - tags = [ - "noasan", # times out b/63678675 - "notsan", # http://b/62189182 - ], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "serialization_test", - size = "small", - srcs = 
["python/keras/layers/serialization_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "wrappers_test", - size = "small", - srcs = ["python/keras/layers/wrappers_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "scikit_learn_test", - size = "small", - srcs = ["python/keras/wrappers/scikit_learn_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "data_utils_test", - size = "small", - srcs = ["python/keras/utils/data_utils_test.py"], - srcs_version = "PY2AND3", - tags = [ - "noasan", # times out - "notsan", - ], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "generic_utils_test", - size = "small", - srcs = ["python/keras/utils/generic_utils_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "io_utils_test", - size = "small", - srcs = ["python/keras/utils/io_utils_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "imagenet_utils_test", - size = "small", - srcs = ["python/keras/applications/imagenet_utils_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "image_test", - size = "medium", - srcs = ["python/keras/preprocessing/image_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "sequence_test", - size = "small", - srcs = ["python/keras/preprocessing/sequence_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "text_test", - size = "small", - srcs = ["python/keras/preprocessing/text_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "callbacks_test", - size = "medium", - srcs = ["python/keras/callbacks_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "training_test", - size = "medium", - srcs = ["python/keras/engine/training_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "topology_test", - size = "small", - srcs = ["python/keras/engine/topology_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//third_party/py/numpy", - ], -) - -py_test( - name = "models_test", - size = "small", - srcs = ["python/keras/models_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training", - "//third_party/py/numpy", - ], -) - -py_test( - name = "backend_test", - size = "small", - srcs = ["python/keras/backend_test.py"], - srcs_version = 
"PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//tensorflow/python:util", - "//third_party/py/numpy", - ], -) - -py_library( - name = "testing_utils", - srcs = [ - "python/keras/testing_utils.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:util", - "//third_party/py/numpy", + "//tensorflow/python/keras", ], ) diff --git a/tensorflow/contrib/keras/README.md b/tensorflow/contrib/keras/README.md index db2556fe422..de4c81268d5 100644 --- a/tensorflow/contrib/keras/README.md +++ b/tensorflow/contrib/keras/README.md @@ -1,3 +1,6 @@ +NOTE: THE `tensorflow.contrib.keras` MODULE HAS BEEN DEPRECATED. +USE INSTEAD `tensorflow.keras`, PART OF CORE TENSORFLOW. + Keras is an object-oriented API for defining and training neural networks. This module contains a pure-TensorFlow implementation of the Keras API, diff --git a/tensorflow/contrib/keras/api/keras/activations/__init__.py b/tensorflow/contrib/keras/api/keras/activations/__init__.py index af6f249e71c..d04838c218d 100644 --- a/tensorflow/contrib/keras/api/keras/activations/__init__.py +++ b/tensorflow/contrib/keras/api/keras/activations/__init__.py @@ -19,22 +19,22 @@ from __future__ import division from __future__ import print_function # Activation functions. -from tensorflow.contrib.keras.python.keras.activations import elu -from tensorflow.contrib.keras.python.keras.activations import hard_sigmoid -from tensorflow.contrib.keras.python.keras.activations import linear -from tensorflow.contrib.keras.python.keras.activations import relu -from tensorflow.contrib.keras.python.keras.activations import selu -from tensorflow.contrib.keras.python.keras.activations import sigmoid -from tensorflow.contrib.keras.python.keras.activations import softmax -from tensorflow.contrib.keras.python.keras.activations import softplus -from tensorflow.contrib.keras.python.keras.activations import softsign -from tensorflow.contrib.keras.python.keras.activations import tanh +from tensorflow.python.keras._impl.keras.activations import elu +from tensorflow.python.keras._impl.keras.activations import hard_sigmoid +from tensorflow.python.keras._impl.keras.activations import linear +from tensorflow.python.keras._impl.keras.activations import relu +from tensorflow.python.keras._impl.keras.activations import selu +from tensorflow.python.keras._impl.keras.activations import sigmoid +from tensorflow.python.keras._impl.keras.activations import softmax +from tensorflow.python.keras._impl.keras.activations import softplus +from tensorflow.python.keras._impl.keras.activations import softsign +from tensorflow.python.keras._impl.keras.activations import tanh # Auxiliary utils. 
# pylint: disable=g-bad-import-order -from tensorflow.contrib.keras.python.keras.activations import deserialize -from tensorflow.contrib.keras.python.keras.activations import serialize -from tensorflow.contrib.keras.python.keras.activations import get +from tensorflow.python.keras._impl.keras.activations import deserialize +from tensorflow.python.keras._impl.keras.activations import serialize +from tensorflow.python.keras._impl.keras.activations import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py index d8ca73fb97f..abf8393ae45 100644 --- a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.applications.inception_v3 import decode_predictions -from tensorflow.contrib.keras.python.keras.applications.inception_v3 import InceptionV3 -from tensorflow.contrib.keras.python.keras.applications.inception_v3 import preprocess_input +from tensorflow.python.keras._impl.keras.applications.inception_v3 import decode_predictions +from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3 +from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py index 594861fb51c..b809e91193b 100644 --- a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.applications.mobilenet import decode_predictions -from tensorflow.contrib.keras.python.keras.applications.mobilenet import MobileNet -from tensorflow.contrib.keras.python.keras.applications.mobilenet import preprocess_input +from tensorflow.python.keras._impl.keras.applications.mobilenet import decode_predictions +from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet +from tensorflow.python.keras._impl.keras.applications.mobilenet import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py index e9b25b66d5a..530805d150b 100644 --- a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.applications.resnet50 import decode_predictions -from tensorflow.contrib.keras.python.keras.applications.resnet50 import preprocess_input -from tensorflow.contrib.keras.python.keras.applications.resnet50 import ResNet50 +from tensorflow.python.keras._impl.keras.applications.resnet50 import decode_predictions +from tensorflow.python.keras._impl.keras.applications.resnet50 import preprocess_input +from 
tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py
index 2a1f789cc51..118361604bb 100644
--- a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py
@@ -18,9 +18,9 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.applications.vgg16 import decode_predictions
-from tensorflow.contrib.keras.python.keras.applications.vgg16 import preprocess_input
-from tensorflow.contrib.keras.python.keras.applications.vgg16 import VGG16
+from tensorflow.python.keras._impl.keras.applications.vgg16 import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.vgg16 import preprocess_input
+from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py
index 22b5e7c8e49..cda52628f3c 100644
--- a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py
@@ -18,9 +18,9 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.applications.vgg19 import decode_predictions
-from tensorflow.contrib.keras.python.keras.applications.vgg19 import preprocess_input
-from tensorflow.contrib.keras.python.keras.applications.vgg19 import VGG19
+from tensorflow.python.keras._impl.keras.applications.vgg19 import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.vgg19 import preprocess_input
+from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py
index 23d1b6a0b37..ae9cd9cd18c 100644
--- a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py
@@ -18,9 +18,9 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.applications.xception import decode_predictions
-from tensorflow.contrib.keras.python.keras.applications.xception import preprocess_input
-from tensorflow.contrib.keras.python.keras.applications.xception import Xception
+from tensorflow.python.keras._impl.keras.applications.xception import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.xception import preprocess_input
+from tensorflow.python.keras._impl.keras.applications.xception import Xception

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/backend/__init__.py b/tensorflow/contrib/keras/api/keras/backend/__init__.py
index f3721a8dcb1..10ef5a75852 100644
--- a/tensorflow/contrib/keras/api/keras/backend/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/backend/__init__.py
@@ -19,144 +19,144 @@
 from __future__ import division
 from __future__ import print_function
 # pylint: disable=redefined-builtin
-from tensorflow.contrib.keras.python.keras.backend import abs
-from tensorflow.contrib.keras.python.keras.backend import all
-from tensorflow.contrib.keras.python.keras.backend import any
-from tensorflow.contrib.keras.python.keras.backend import arange
-from tensorflow.contrib.keras.python.keras.backend import argmax
-from tensorflow.contrib.keras.python.keras.backend import argmin
-from tensorflow.contrib.keras.python.keras.backend import backend
-from tensorflow.contrib.keras.python.keras.backend import batch_dot
-from tensorflow.contrib.keras.python.keras.backend import batch_flatten
-from tensorflow.contrib.keras.python.keras.backend import batch_get_value
-from tensorflow.contrib.keras.python.keras.backend import batch_normalization
-from tensorflow.contrib.keras.python.keras.backend import batch_set_value
-from tensorflow.contrib.keras.python.keras.backend import bias_add
-from tensorflow.contrib.keras.python.keras.backend import binary_crossentropy
-from tensorflow.contrib.keras.python.keras.backend import cast
-from tensorflow.contrib.keras.python.keras.backend import cast_to_floatx
-from tensorflow.contrib.keras.python.keras.backend import categorical_crossentropy
-from tensorflow.contrib.keras.python.keras.backend import clear_session
-from tensorflow.contrib.keras.python.keras.backend import clip
-from tensorflow.contrib.keras.python.keras.backend import concatenate
-from tensorflow.contrib.keras.python.keras.backend import constant
-from tensorflow.contrib.keras.python.keras.backend import conv1d
-from tensorflow.contrib.keras.python.keras.backend import conv2d
-from tensorflow.contrib.keras.python.keras.backend import conv2d_transpose
-from tensorflow.contrib.keras.python.keras.backend import conv3d
-from tensorflow.contrib.keras.python.keras.backend import cos
-from tensorflow.contrib.keras.python.keras.backend import count_params
-from tensorflow.contrib.keras.python.keras.backend import ctc_batch_cost
-from tensorflow.contrib.keras.python.keras.backend import ctc_decode
-from tensorflow.contrib.keras.python.keras.backend import ctc_label_dense_to_sparse
-from tensorflow.contrib.keras.python.keras.backend import dot
-from tensorflow.contrib.keras.python.keras.backend import dropout
-from tensorflow.contrib.keras.python.keras.backend import dtype
-from tensorflow.contrib.keras.python.keras.backend import elu
-from tensorflow.contrib.keras.python.keras.backend import epsilon
-from tensorflow.contrib.keras.python.keras.backend import equal
-from tensorflow.contrib.keras.python.keras.backend import eval
-from tensorflow.contrib.keras.python.keras.backend import exp
-from tensorflow.contrib.keras.python.keras.backend import expand_dims
-from tensorflow.contrib.keras.python.keras.backend import eye
-from tensorflow.contrib.keras.python.keras.backend import flatten
-from tensorflow.contrib.keras.python.keras.backend import floatx
-from tensorflow.contrib.keras.python.keras.backend import foldl
-from tensorflow.contrib.keras.python.keras.backend import foldr
-from tensorflow.contrib.keras.python.keras.backend import function
-from tensorflow.contrib.keras.python.keras.backend import gather
-from tensorflow.contrib.keras.python.keras.backend import get_session
-from tensorflow.contrib.keras.python.keras.backend import get_uid
-from tensorflow.contrib.keras.python.keras.backend import get_value
-from tensorflow.contrib.keras.python.keras.backend import gradients
-from tensorflow.contrib.keras.python.keras.backend import greater
-from tensorflow.contrib.keras.python.keras.backend import greater_equal
-from tensorflow.contrib.keras.python.keras.backend import hard_sigmoid
-from tensorflow.contrib.keras.python.keras.backend import image_data_format
-from tensorflow.contrib.keras.python.keras.backend import in_test_phase
-from tensorflow.contrib.keras.python.keras.backend import in_top_k
-from tensorflow.contrib.keras.python.keras.backend import in_train_phase
-from tensorflow.contrib.keras.python.keras.backend import int_shape
-from tensorflow.contrib.keras.python.keras.backend import is_sparse
-from tensorflow.contrib.keras.python.keras.backend import l2_normalize
-from tensorflow.contrib.keras.python.keras.backend import learning_phase
-from tensorflow.contrib.keras.python.keras.backend import less
-from tensorflow.contrib.keras.python.keras.backend import less_equal
-from tensorflow.contrib.keras.python.keras.backend import log
-from tensorflow.contrib.keras.python.keras.backend import manual_variable_initialization
-from tensorflow.contrib.keras.python.keras.backend import map_fn
-from tensorflow.contrib.keras.python.keras.backend import max
-from tensorflow.contrib.keras.python.keras.backend import maximum
-from tensorflow.contrib.keras.python.keras.backend import mean
-from tensorflow.contrib.keras.python.keras.backend import min
-from tensorflow.contrib.keras.python.keras.backend import minimum
-from tensorflow.contrib.keras.python.keras.backend import moving_average_update
-from tensorflow.contrib.keras.python.keras.backend import name_scope
-from tensorflow.contrib.keras.python.keras.backend import ndim
-from tensorflow.contrib.keras.python.keras.backend import normalize_batch_in_training
-from tensorflow.contrib.keras.python.keras.backend import not_equal
-from tensorflow.contrib.keras.python.keras.backend import one_hot
-from tensorflow.contrib.keras.python.keras.backend import ones
-from tensorflow.contrib.keras.python.keras.backend import ones_like
-from tensorflow.contrib.keras.python.keras.backend import permute_dimensions
-from tensorflow.contrib.keras.python.keras.backend import placeholder
-from tensorflow.contrib.keras.python.keras.backend import pool2d
-from tensorflow.contrib.keras.python.keras.backend import pool3d
-from tensorflow.contrib.keras.python.keras.backend import pow
-from tensorflow.contrib.keras.python.keras.backend import print_tensor
-from tensorflow.contrib.keras.python.keras.backend import prod
-from tensorflow.contrib.keras.python.keras.backend import random_binomial
-from tensorflow.contrib.keras.python.keras.backend import random_normal
-from tensorflow.contrib.keras.python.keras.backend import random_normal_variable
-from tensorflow.contrib.keras.python.keras.backend import random_uniform
-from tensorflow.contrib.keras.python.keras.backend import random_uniform_variable
-from tensorflow.contrib.keras.python.keras.backend import relu
-from tensorflow.contrib.keras.python.keras.backend import repeat
-from tensorflow.contrib.keras.python.keras.backend import repeat_elements
-from tensorflow.contrib.keras.python.keras.backend import reset_uids
-from tensorflow.contrib.keras.python.keras.backend import reshape
-from tensorflow.contrib.keras.python.keras.backend import resize_images
-from tensorflow.contrib.keras.python.keras.backend import resize_volumes
-from tensorflow.contrib.keras.python.keras.backend import reverse
-from tensorflow.contrib.keras.python.keras.backend import rnn
-from tensorflow.contrib.keras.python.keras.backend import round
-from tensorflow.contrib.keras.python.keras.backend import separable_conv2d
-from tensorflow.contrib.keras.python.keras.backend import set_epsilon
-from tensorflow.contrib.keras.python.keras.backend import set_floatx
-from tensorflow.contrib.keras.python.keras.backend import set_image_data_format
-from tensorflow.contrib.keras.python.keras.backend import set_learning_phase
-from tensorflow.contrib.keras.python.keras.backend import set_session
-from tensorflow.contrib.keras.python.keras.backend import set_value
-from tensorflow.contrib.keras.python.keras.backend import shape
-from tensorflow.contrib.keras.python.keras.backend import sigmoid
-from tensorflow.contrib.keras.python.keras.backend import sign
-from tensorflow.contrib.keras.python.keras.backend import sin
-from tensorflow.contrib.keras.python.keras.backend import softmax
-from tensorflow.contrib.keras.python.keras.backend import softplus
-from tensorflow.contrib.keras.python.keras.backend import softsign
-from tensorflow.contrib.keras.python.keras.backend import sparse_categorical_crossentropy
-from tensorflow.contrib.keras.python.keras.backend import spatial_2d_padding
-from tensorflow.contrib.keras.python.keras.backend import spatial_3d_padding
-from tensorflow.contrib.keras.python.keras.backend import sqrt
-from tensorflow.contrib.keras.python.keras.backend import square
-from tensorflow.contrib.keras.python.keras.backend import squeeze
-from tensorflow.contrib.keras.python.keras.backend import stack
-from tensorflow.contrib.keras.python.keras.backend import std
-from tensorflow.contrib.keras.python.keras.backend import stop_gradient
-from tensorflow.contrib.keras.python.keras.backend import sum
-from tensorflow.contrib.keras.python.keras.backend import switch
-from tensorflow.contrib.keras.python.keras.backend import tanh
-from tensorflow.contrib.keras.python.keras.backend import temporal_padding
-from tensorflow.contrib.keras.python.keras.backend import to_dense
-from tensorflow.contrib.keras.python.keras.backend import transpose
-from tensorflow.contrib.keras.python.keras.backend import truncated_normal
-from tensorflow.contrib.keras.python.keras.backend import update
-from tensorflow.contrib.keras.python.keras.backend import update_add
-from tensorflow.contrib.keras.python.keras.backend import update_sub
-from tensorflow.contrib.keras.python.keras.backend import var
-from tensorflow.contrib.keras.python.keras.backend import variable
-from tensorflow.contrib.keras.python.keras.backend import zeros
-from tensorflow.contrib.keras.python.keras.backend import zeros_like
+from tensorflow.python.keras._impl.keras.backend import abs
+from tensorflow.python.keras._impl.keras.backend import all
+from tensorflow.python.keras._impl.keras.backend import any
+from tensorflow.python.keras._impl.keras.backend import arange
+from tensorflow.python.keras._impl.keras.backend import argmax
+from tensorflow.python.keras._impl.keras.backend import argmin
+from tensorflow.python.keras._impl.keras.backend import backend
+from tensorflow.python.keras._impl.keras.backend import batch_dot
+from tensorflow.python.keras._impl.keras.backend import batch_flatten
+from tensorflow.python.keras._impl.keras.backend import batch_get_value
+from tensorflow.python.keras._impl.keras.backend import batch_normalization
+from tensorflow.python.keras._impl.keras.backend import batch_set_value
+from tensorflow.python.keras._impl.keras.backend import bias_add
+from tensorflow.python.keras._impl.keras.backend import binary_crossentropy
+from tensorflow.python.keras._impl.keras.backend import cast
+from tensorflow.python.keras._impl.keras.backend import cast_to_floatx
+from tensorflow.python.keras._impl.keras.backend import categorical_crossentropy
+from tensorflow.python.keras._impl.keras.backend import clear_session
+from tensorflow.python.keras._impl.keras.backend import clip
+from tensorflow.python.keras._impl.keras.backend import concatenate
+from tensorflow.python.keras._impl.keras.backend import constant
+from tensorflow.python.keras._impl.keras.backend import conv1d
+from tensorflow.python.keras._impl.keras.backend import conv2d
+from tensorflow.python.keras._impl.keras.backend import conv2d_transpose
+from tensorflow.python.keras._impl.keras.backend import conv3d
+from tensorflow.python.keras._impl.keras.backend import cos
+from tensorflow.python.keras._impl.keras.backend import count_params
+from tensorflow.python.keras._impl.keras.backend import ctc_batch_cost
+from tensorflow.python.keras._impl.keras.backend import ctc_decode
+from tensorflow.python.keras._impl.keras.backend import ctc_label_dense_to_sparse
+from tensorflow.python.keras._impl.keras.backend import dot
+from tensorflow.python.keras._impl.keras.backend import dropout
+from tensorflow.python.keras._impl.keras.backend import dtype
+from tensorflow.python.keras._impl.keras.backend import elu
+from tensorflow.python.keras._impl.keras.backend import epsilon
+from tensorflow.python.keras._impl.keras.backend import equal
+from tensorflow.python.keras._impl.keras.backend import eval
+from tensorflow.python.keras._impl.keras.backend import exp
+from tensorflow.python.keras._impl.keras.backend import expand_dims
+from tensorflow.python.keras._impl.keras.backend import eye
+from tensorflow.python.keras._impl.keras.backend import flatten
+from tensorflow.python.keras._impl.keras.backend import floatx
+from tensorflow.python.keras._impl.keras.backend import foldl
+from tensorflow.python.keras._impl.keras.backend import foldr
+from tensorflow.python.keras._impl.keras.backend import function
+from tensorflow.python.keras._impl.keras.backend import gather
+from tensorflow.python.keras._impl.keras.backend import get_session
+from tensorflow.python.keras._impl.keras.backend import get_uid
+from tensorflow.python.keras._impl.keras.backend import get_value
+from tensorflow.python.keras._impl.keras.backend import gradients
+from tensorflow.python.keras._impl.keras.backend import greater
+from tensorflow.python.keras._impl.keras.backend import greater_equal
+from tensorflow.python.keras._impl.keras.backend import hard_sigmoid
+from tensorflow.python.keras._impl.keras.backend import image_data_format
+from tensorflow.python.keras._impl.keras.backend import in_test_phase
+from tensorflow.python.keras._impl.keras.backend import in_top_k
+from tensorflow.python.keras._impl.keras.backend import in_train_phase
+from tensorflow.python.keras._impl.keras.backend import int_shape
+from tensorflow.python.keras._impl.keras.backend import is_sparse
+from tensorflow.python.keras._impl.keras.backend import l2_normalize
+from tensorflow.python.keras._impl.keras.backend import learning_phase
+from tensorflow.python.keras._impl.keras.backend import less
+from tensorflow.python.keras._impl.keras.backend import less_equal
+from tensorflow.python.keras._impl.keras.backend import log
+from tensorflow.python.keras._impl.keras.backend import manual_variable_initialization
+from tensorflow.python.keras._impl.keras.backend import map_fn
+from tensorflow.python.keras._impl.keras.backend import max
+from tensorflow.python.keras._impl.keras.backend import maximum
+from tensorflow.python.keras._impl.keras.backend import mean
+from tensorflow.python.keras._impl.keras.backend import min
+from tensorflow.python.keras._impl.keras.backend import minimum
+from tensorflow.python.keras._impl.keras.backend import moving_average_update
+from tensorflow.python.keras._impl.keras.backend import name_scope
+from tensorflow.python.keras._impl.keras.backend import ndim
+from tensorflow.python.keras._impl.keras.backend import normalize_batch_in_training
+from tensorflow.python.keras._impl.keras.backend import not_equal
+from tensorflow.python.keras._impl.keras.backend import one_hot
+from tensorflow.python.keras._impl.keras.backend import ones
+from tensorflow.python.keras._impl.keras.backend import ones_like
+from tensorflow.python.keras._impl.keras.backend import permute_dimensions
+from tensorflow.python.keras._impl.keras.backend import placeholder
+from tensorflow.python.keras._impl.keras.backend import pool2d
+from tensorflow.python.keras._impl.keras.backend import pool3d
+from tensorflow.python.keras._impl.keras.backend import pow
+from tensorflow.python.keras._impl.keras.backend import print_tensor
+from tensorflow.python.keras._impl.keras.backend import prod
+from tensorflow.python.keras._impl.keras.backend import random_binomial
+from tensorflow.python.keras._impl.keras.backend import random_normal
+from tensorflow.python.keras._impl.keras.backend import random_normal_variable
+from tensorflow.python.keras._impl.keras.backend import random_uniform
+from tensorflow.python.keras._impl.keras.backend import random_uniform_variable
+from tensorflow.python.keras._impl.keras.backend import relu
+from tensorflow.python.keras._impl.keras.backend import repeat
+from tensorflow.python.keras._impl.keras.backend import repeat_elements
+from tensorflow.python.keras._impl.keras.backend import reset_uids
+from tensorflow.python.keras._impl.keras.backend import reshape
+from tensorflow.python.keras._impl.keras.backend import resize_images
+from tensorflow.python.keras._impl.keras.backend import resize_volumes
+from tensorflow.python.keras._impl.keras.backend import reverse
+from tensorflow.python.keras._impl.keras.backend import rnn
+from tensorflow.python.keras._impl.keras.backend import round
+from tensorflow.python.keras._impl.keras.backend import separable_conv2d
+from tensorflow.python.keras._impl.keras.backend import set_epsilon
+from tensorflow.python.keras._impl.keras.backend import set_floatx
+from tensorflow.python.keras._impl.keras.backend import set_image_data_format
+from tensorflow.python.keras._impl.keras.backend import set_learning_phase
+from tensorflow.python.keras._impl.keras.backend import set_session
+from tensorflow.python.keras._impl.keras.backend import set_value
+from tensorflow.python.keras._impl.keras.backend import shape
+from tensorflow.python.keras._impl.keras.backend import sigmoid
+from tensorflow.python.keras._impl.keras.backend import sign
+from tensorflow.python.keras._impl.keras.backend import sin
+from tensorflow.python.keras._impl.keras.backend import softmax
+from tensorflow.python.keras._impl.keras.backend import softplus
+from tensorflow.python.keras._impl.keras.backend import softsign
+from tensorflow.python.keras._impl.keras.backend import sparse_categorical_crossentropy
+from tensorflow.python.keras._impl.keras.backend import spatial_2d_padding
+from tensorflow.python.keras._impl.keras.backend import spatial_3d_padding
+from tensorflow.python.keras._impl.keras.backend import sqrt
+from tensorflow.python.keras._impl.keras.backend import square
+from tensorflow.python.keras._impl.keras.backend import squeeze
+from tensorflow.python.keras._impl.keras.backend import stack
+from tensorflow.python.keras._impl.keras.backend import std
+from tensorflow.python.keras._impl.keras.backend import stop_gradient
+from tensorflow.python.keras._impl.keras.backend import sum
+from tensorflow.python.keras._impl.keras.backend import switch
+from tensorflow.python.keras._impl.keras.backend import tanh
+from tensorflow.python.keras._impl.keras.backend import temporal_padding
+from tensorflow.python.keras._impl.keras.backend import to_dense
+from tensorflow.python.keras._impl.keras.backend import transpose
+from tensorflow.python.keras._impl.keras.backend import truncated_normal
+from tensorflow.python.keras._impl.keras.backend import update
+from tensorflow.python.keras._impl.keras.backend import update_add
+from tensorflow.python.keras._impl.keras.backend import update_sub
+from tensorflow.python.keras._impl.keras.backend import var
+from tensorflow.python.keras._impl.keras.backend import variable
+from tensorflow.python.keras._impl.keras.backend import zeros
+from tensorflow.python.keras._impl.keras.backend import zeros_like

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py
index 3a970748573..2d884790ddb 100644
--- a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py
@@ -18,19 +18,19 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.callbacks import BaseLogger
-from tensorflow.contrib.keras.python.keras.callbacks import Callback
-from tensorflow.contrib.keras.python.keras.callbacks import CSVLogger
-from tensorflow.contrib.keras.python.keras.callbacks import EarlyStopping
-from tensorflow.contrib.keras.python.keras.callbacks import History
-from tensorflow.contrib.keras.python.keras.callbacks import LambdaCallback
-from tensorflow.contrib.keras.python.keras.callbacks import LearningRateScheduler
-from tensorflow.contrib.keras.python.keras.callbacks import ModelCheckpoint
-from tensorflow.contrib.keras.python.keras.callbacks import ProgbarLogger
-from tensorflow.contrib.keras.python.keras.callbacks import ReduceLROnPlateau
-from tensorflow.contrib.keras.python.keras.callbacks import RemoteMonitor
-from tensorflow.contrib.keras.python.keras.callbacks import TensorBoard
-from tensorflow.contrib.keras.python.keras.callbacks import TerminateOnNaN
+from tensorflow.python.keras._impl.keras.callbacks import BaseLogger
+from tensorflow.python.keras._impl.keras.callbacks import Callback
+from tensorflow.python.keras._impl.keras.callbacks import CSVLogger
+from tensorflow.python.keras._impl.keras.callbacks import EarlyStopping
+from tensorflow.python.keras._impl.keras.callbacks import History
+from tensorflow.python.keras._impl.keras.callbacks import LambdaCallback
+from tensorflow.python.keras._impl.keras.callbacks import LearningRateScheduler
+from tensorflow.python.keras._impl.keras.callbacks import ModelCheckpoint
+from tensorflow.python.keras._impl.keras.callbacks import ProgbarLogger
+from tensorflow.python.keras._impl.keras.callbacks import ReduceLROnPlateau
+from tensorflow.python.keras._impl.keras.callbacks import RemoteMonitor
+from tensorflow.python.keras._impl.keras.callbacks import TensorBoard
+from tensorflow.python.keras._impl.keras.callbacks import TerminateOnNaN

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/constraints/__init__.py b/tensorflow/contrib/keras/api/keras/constraints/__init__.py
index 6b9e3bf46e3..152606d8ebb 100644
--- a/tensorflow/contrib/keras/api/keras/constraints/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/constraints/__init__.py
@@ -19,21 +19,21 @@
 from __future__ import division
 from __future__ import print_function
 # Constraints functions / callable classes.
-from tensorflow.contrib.keras.python.keras.constraints import Constraint
-from tensorflow.contrib.keras.python.keras.constraints import max_norm
-from tensorflow.contrib.keras.python.keras.constraints import MaxNorm
-from tensorflow.contrib.keras.python.keras.constraints import min_max_norm
-from tensorflow.contrib.keras.python.keras.constraints import MinMaxNorm
-from tensorflow.contrib.keras.python.keras.constraints import non_neg
-from tensorflow.contrib.keras.python.keras.constraints import NonNeg
-from tensorflow.contrib.keras.python.keras.constraints import unit_norm
-from tensorflow.contrib.keras.python.keras.constraints import UnitNorm
+from tensorflow.python.keras._impl.keras.constraints import Constraint
+from tensorflow.python.keras._impl.keras.constraints import max_norm
+from tensorflow.python.keras._impl.keras.constraints import MaxNorm
+from tensorflow.python.keras._impl.keras.constraints import min_max_norm
+from tensorflow.python.keras._impl.keras.constraints import MinMaxNorm
+from tensorflow.python.keras._impl.keras.constraints import non_neg
+from tensorflow.python.keras._impl.keras.constraints import NonNeg
+from tensorflow.python.keras._impl.keras.constraints import unit_norm
+from tensorflow.python.keras._impl.keras.constraints import UnitNorm

 # Auxiliary utils.
 # pylint: disable=g-bad-import-order
-from tensorflow.contrib.keras.python.keras.constraints import deserialize
-from tensorflow.contrib.keras.python.keras.constraints import serialize
-from tensorflow.contrib.keras.python.keras.constraints import get
+from tensorflow.python.keras._impl.keras.constraints import deserialize
+from tensorflow.python.keras._impl.keras.constraints import serialize
+from tensorflow.python.keras._impl.keras.constraints import get

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py
index 0bfd3df5401..b5371a03fd5 100644
--- a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py
@@ -18,7 +18,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.datasets.boston_housing import load_data
+from tensorflow.python.keras._impl.keras.datasets.boston_housing import load_data

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py
index f5fac6982ac..68d3eb789ea 100644
--- a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py
@@ -18,7 +18,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.datasets.cifar10 import load_data
+from tensorflow.python.keras._impl.keras.datasets.cifar10 import load_data

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py
index a7e69961363..ca937426733 100644
--- a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py
@@ -18,7 +18,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.datasets.cifar100 import load_data
+from tensorflow.python.keras._impl.keras.datasets.cifar100 import load_data

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py
index f141c8a8e98..1c6396d2d32 100644
--- a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py
@@ -18,8 +18,8 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.datasets.imdb import get_word_index
-from tensorflow.contrib.keras.python.keras.datasets.imdb import load_data
+from tensorflow.python.keras._impl.keras.datasets.imdb import get_word_index
+from tensorflow.python.keras._impl.keras.datasets.imdb import load_data

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py
index 50b74f149c1..364255f3387 100644
--- a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py
@@ -18,7 +18,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.datasets.mnist import load_data
+from tensorflow.python.keras._impl.keras.datasets.mnist import load_data

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py
index fc7f1235a3a..bb6791a344a 100644
--- a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py
@@ -18,8 +18,8 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.datasets.reuters import get_word_index
-from tensorflow.contrib.keras.python.keras.datasets.reuters import load_data
+from tensorflow.python.keras._impl.keras.datasets.reuters import get_word_index
+from tensorflow.python.keras._impl.keras.datasets.reuters import load_data

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/initializers/__init__.py b/tensorflow/contrib/keras/api/keras/initializers/__init__.py
index 9b58723ed5c..6b1fcfd2d95 100644
--- a/tensorflow/contrib/keras/api/keras/initializers/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/initializers/__init__.py
@@ -19,30 +19,30 @@
 from __future__ import division
 from __future__ import print_function
 # Initializer functions / callable classes.
-from tensorflow.contrib.keras.python.keras.initializers import Constant
-from tensorflow.contrib.keras.python.keras.initializers import Identity
-from tensorflow.contrib.keras.python.keras.initializers import Initializer
-from tensorflow.contrib.keras.python.keras.initializers import Ones
-from tensorflow.contrib.keras.python.keras.initializers import Orthogonal
-from tensorflow.contrib.keras.python.keras.initializers import RandomNormal
-from tensorflow.contrib.keras.python.keras.initializers import RandomUniform
-from tensorflow.contrib.keras.python.keras.initializers import TruncatedNormal
-from tensorflow.contrib.keras.python.keras.initializers import VarianceScaling
-from tensorflow.contrib.keras.python.keras.initializers import Zeros
+from tensorflow.python.keras._impl.keras.initializers import Constant
+from tensorflow.python.keras._impl.keras.initializers import Identity
+from tensorflow.python.keras._impl.keras.initializers import Initializer
+from tensorflow.python.keras._impl.keras.initializers import Ones
+from tensorflow.python.keras._impl.keras.initializers import Orthogonal
+from tensorflow.python.keras._impl.keras.initializers import RandomNormal
+from tensorflow.python.keras._impl.keras.initializers import RandomUniform
+from tensorflow.python.keras._impl.keras.initializers import TruncatedNormal
+from tensorflow.python.keras._impl.keras.initializers import VarianceScaling
+from tensorflow.python.keras._impl.keras.initializers import Zeros

 # Functional interface.
 # pylint: disable=g-bad-import-order
-from tensorflow.contrib.keras.python.keras.initializers import glorot_normal
-from tensorflow.contrib.keras.python.keras.initializers import glorot_uniform
-from tensorflow.contrib.keras.python.keras.initializers import he_normal
-from tensorflow.contrib.keras.python.keras.initializers import he_uniform
-from tensorflow.contrib.keras.python.keras.initializers import lecun_normal
-from tensorflow.contrib.keras.python.keras.initializers import lecun_uniform
+from tensorflow.python.keras._impl.keras.initializers import glorot_normal
+from tensorflow.python.keras._impl.keras.initializers import glorot_uniform
+from tensorflow.python.keras._impl.keras.initializers import he_normal
+from tensorflow.python.keras._impl.keras.initializers import he_uniform
+from tensorflow.python.keras._impl.keras.initializers import lecun_normal
+from tensorflow.python.keras._impl.keras.initializers import lecun_uniform

 # Auxiliary utils.
-from tensorflow.contrib.keras.python.keras.initializers import deserialize
-from tensorflow.contrib.keras.python.keras.initializers import serialize
-from tensorflow.contrib.keras.python.keras.initializers import get
+from tensorflow.python.keras._impl.keras.initializers import deserialize
+from tensorflow.python.keras._impl.keras.initializers import serialize
+from tensorflow.python.keras._impl.keras.initializers import get

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py
index aafd1892175..acf0a5e1799 100644
--- a/tensorflow/contrib/keras/api/keras/layers/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py
@@ -20,128 +20,128 @@
 from __future__ import print_function
 # Generic layers.
 # pylint: disable=g-bad-import-order
-from tensorflow.contrib.keras.python.keras.engine import Input
-from tensorflow.contrib.keras.python.keras.engine import InputLayer
-from tensorflow.contrib.keras.python.keras.engine import InputSpec
-from tensorflow.contrib.keras.python.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine import Input
+from tensorflow.python.keras._impl.keras.engine import InputLayer
+from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine import Layer

 # Advanced activations.
-from tensorflow.contrib.keras.python.keras.layers.advanced_activations import LeakyReLU
-from tensorflow.contrib.keras.python.keras.layers.advanced_activations import PReLU
-from tensorflow.contrib.keras.python.keras.layers.advanced_activations import ELU
-from tensorflow.contrib.keras.python.keras.layers.advanced_activations import ThresholdedReLU
+from tensorflow.python.keras._impl.keras.layers.advanced_activations import LeakyReLU
+from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU
+from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU
+from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU

 # Convolution layers.
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Conv1D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Conv2D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Conv3D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Conv2DTranspose
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Conv3DTranspose
-from tensorflow.contrib.keras.python.keras.layers.convolutional import SeparableConv2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose
+from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose
+from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D

 # Convolution layer aliases.
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Convolution1D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Convolution2D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Convolution3D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Convolution2DTranspose
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Convolution3DTranspose
-from tensorflow.contrib.keras.python.keras.layers.convolutional import SeparableConvolution2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution1D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose
+from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose
+from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D

 # Image processing layers.
-from tensorflow.contrib.keras.python.keras.layers.convolutional import UpSampling1D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import UpSampling2D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import UpSampling3D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import ZeroPadding1D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import ZeroPadding2D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import ZeroPadding3D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Cropping1D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Cropping2D
-from tensorflow.contrib.keras.python.keras.layers.convolutional import Cropping3D
+from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D
+from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling3D
+from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding1D
+from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding3D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping1D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping3D

 # Convolutional-recurrent layers.
-from tensorflow.contrib.keras.python.keras.layers.convolutional_recurrent import ConvLSTM2D
+from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import ConvLSTM2D

 # Core layers.
-from tensorflow.contrib.keras.python.keras.layers.core import Masking
-from tensorflow.contrib.keras.python.keras.layers.core import Dropout
-from tensorflow.contrib.keras.python.keras.layers.core import SpatialDropout1D
-from tensorflow.contrib.keras.python.keras.layers.core import SpatialDropout2D
-from tensorflow.contrib.keras.python.keras.layers.core import SpatialDropout3D
-from tensorflow.contrib.keras.python.keras.layers.core import Activation
-from tensorflow.contrib.keras.python.keras.layers.core import Reshape
-from tensorflow.contrib.keras.python.keras.layers.core import Permute
-from tensorflow.contrib.keras.python.keras.layers.core import Flatten
-from tensorflow.contrib.keras.python.keras.layers.core import RepeatVector
-from tensorflow.contrib.keras.python.keras.layers.core import Lambda
-from tensorflow.contrib.keras.python.keras.layers.core import Dense
-from tensorflow.contrib.keras.python.keras.layers.core import ActivityRegularization
+from tensorflow.python.keras._impl.keras.layers.core import Masking
+from tensorflow.python.keras._impl.keras.layers.core import Dropout
+from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout1D
+from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout2D
+from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout3D
+from tensorflow.python.keras._impl.keras.layers.core import Activation
+from tensorflow.python.keras._impl.keras.layers.core import Reshape
+from tensorflow.python.keras._impl.keras.layers.core import Permute
+from tensorflow.python.keras._impl.keras.layers.core import Flatten
+from tensorflow.python.keras._impl.keras.layers.core import RepeatVector
+from tensorflow.python.keras._impl.keras.layers.core import Lambda
+from tensorflow.python.keras._impl.keras.layers.core import Dense
+from tensorflow.python.keras._impl.keras.layers.core import ActivityRegularization

 # Embedding layers.
-from tensorflow.contrib.keras.python.keras.layers.embeddings import Embedding
+from tensorflow.python.keras._impl.keras.layers.embeddings import Embedding

 # Locally-connected layers.
-from tensorflow.contrib.keras.python.keras.layers.local import LocallyConnected1D
-from tensorflow.contrib.keras.python.keras.layers.local import LocallyConnected2D
+from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected1D
+from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected2D

 # Merge layers.
-from tensorflow.contrib.keras.python.keras.layers.merge import Add
-from tensorflow.contrib.keras.python.keras.layers.merge import Multiply
-from tensorflow.contrib.keras.python.keras.layers.merge import Average
-from tensorflow.contrib.keras.python.keras.layers.merge import Maximum
-from tensorflow.contrib.keras.python.keras.layers.merge import Concatenate
-from tensorflow.contrib.keras.python.keras.layers.merge import Dot
-from tensorflow.contrib.keras.python.keras.layers.merge import add
-from tensorflow.contrib.keras.python.keras.layers.merge import multiply
-from tensorflow.contrib.keras.python.keras.layers.merge import average
-from tensorflow.contrib.keras.python.keras.layers.merge import maximum
-from tensorflow.contrib.keras.python.keras.layers.merge import concatenate
-from tensorflow.contrib.keras.python.keras.layers.merge import dot
+from tensorflow.python.keras._impl.keras.layers.merge import Add
+from tensorflow.python.keras._impl.keras.layers.merge import Multiply
+from tensorflow.python.keras._impl.keras.layers.merge import Average
+from tensorflow.python.keras._impl.keras.layers.merge import Maximum
+from tensorflow.python.keras._impl.keras.layers.merge import Concatenate
+from tensorflow.python.keras._impl.keras.layers.merge import Dot
+from tensorflow.python.keras._impl.keras.layers.merge import add
+from tensorflow.python.keras._impl.keras.layers.merge import multiply
+from tensorflow.python.keras._impl.keras.layers.merge import average
+from tensorflow.python.keras._impl.keras.layers.merge import maximum
+from tensorflow.python.keras._impl.keras.layers.merge import concatenate
+from tensorflow.python.keras._impl.keras.layers.merge import dot

 # Noise layers.
-from tensorflow.contrib.keras.python.keras.layers.noise import AlphaDropout
-from tensorflow.contrib.keras.python.keras.layers.noise import GaussianNoise
-from tensorflow.contrib.keras.python.keras.layers.noise import GaussianDropout
+from tensorflow.python.keras._impl.keras.layers.noise import AlphaDropout
+from tensorflow.python.keras._impl.keras.layers.noise import GaussianNoise
+from tensorflow.python.keras._impl.keras.layers.noise import GaussianDropout

 # Normalization layers.
-from tensorflow.contrib.keras.python.keras.layers.normalization import BatchNormalization
+from tensorflow.python.keras._impl.keras.layers.normalization import BatchNormalization

 # Pooling layers.
-from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPooling1D
-from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPooling2D
-from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPooling3D
-from tensorflow.contrib.keras.python.keras.layers.pooling import AveragePooling1D
-from tensorflow.contrib.keras.python.keras.layers.pooling import AveragePooling2D
-from tensorflow.contrib.keras.python.keras.layers.pooling import AveragePooling3D
-from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalAveragePooling1D
-from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalAveragePooling2D
-from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalAveragePooling3D
-from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalMaxPooling1D
-from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalMaxPooling2D
-from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalMaxPooling3D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling1D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D
+from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D
+from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling2D
+from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling3D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling1D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling2D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling3D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling1D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling2D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling3D

 # Pooling layer aliases.
-from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPool1D
-from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPool2D
-from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPool3D
-from tensorflow.contrib.keras.python.keras.layers.pooling import AvgPool1D
-from tensorflow.contrib.keras.python.keras.layers.pooling import AvgPool2D
-from tensorflow.contrib.keras.python.keras.layers.pooling import AvgPool3D
-from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalAvgPool1D
-from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalAvgPool2D
-from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalAvgPool3D
-from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalMaxPool1D
-from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalMaxPool2D
-from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalMaxPool3D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool1D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool2D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool3D
+from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool1D
+from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool2D
+from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool3D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool1D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool2D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool3D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool1D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D

 # Recurrent layers.
-from tensorflow.contrib.keras.python.keras.layers.recurrent import SimpleRNN
-from tensorflow.contrib.keras.python.keras.layers.recurrent import GRU
-from tensorflow.contrib.keras.python.keras.layers.recurrent import LSTM
+from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN
+from tensorflow.python.keras._impl.keras.layers.recurrent import GRU
+from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM

 # Wrapper functions
-from tensorflow.contrib.keras.python.keras.layers.wrappers import Wrapper
-from tensorflow.contrib.keras.python.keras.layers.wrappers import Bidirectional
-from tensorflow.contrib.keras.python.keras.layers.wrappers import TimeDistributed
+from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper
+from tensorflow.python.keras._impl.keras.layers.wrappers import Bidirectional
+from tensorflow.python.keras._impl.keras.layers.wrappers import TimeDistributed

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/losses/__init__.py b/tensorflow/contrib/keras/api/keras/losses/__init__.py
index 06dd679f9ca..66721b694f5 100644
--- a/tensorflow/contrib/keras/api/keras/losses/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/losses/__init__.py
@@ -19,26 +19,26 @@
 from __future__ import division
 from __future__ import print_function
 # Loss functions.
-from tensorflow.contrib.keras.python.keras.losses import binary_crossentropy
-from tensorflow.contrib.keras.python.keras.losses import categorical_crossentropy
-from tensorflow.contrib.keras.python.keras.losses import categorical_hinge
-from tensorflow.contrib.keras.python.keras.losses import cosine_proximity
-from tensorflow.contrib.keras.python.keras.losses import hinge
-from tensorflow.contrib.keras.python.keras.losses import kullback_leibler_divergence
-from tensorflow.contrib.keras.python.keras.losses import logcosh
-from tensorflow.contrib.keras.python.keras.losses import mean_absolute_error
-from tensorflow.contrib.keras.python.keras.losses import mean_absolute_percentage_error
-from tensorflow.contrib.keras.python.keras.losses import mean_squared_error
-from tensorflow.contrib.keras.python.keras.losses import mean_squared_logarithmic_error
-from tensorflow.contrib.keras.python.keras.losses import poisson
-from tensorflow.contrib.keras.python.keras.losses import sparse_categorical_crossentropy
-from tensorflow.contrib.keras.python.keras.losses import squared_hinge
+from tensorflow.python.keras._impl.keras.losses import binary_crossentropy
+from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy
+from tensorflow.python.keras._impl.keras.losses import categorical_hinge
+from tensorflow.python.keras._impl.keras.losses import cosine_proximity
+from tensorflow.python.keras._impl.keras.losses import hinge
+from tensorflow.python.keras._impl.keras.losses import kullback_leibler_divergence
+from tensorflow.python.keras._impl.keras.losses import logcosh
+from tensorflow.python.keras._impl.keras.losses import mean_absolute_error
+from tensorflow.python.keras._impl.keras.losses import mean_absolute_percentage_error
+from tensorflow.python.keras._impl.keras.losses import mean_squared_error
+from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_error
+from tensorflow.python.keras._impl.keras.losses import poisson
+from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy
+from tensorflow.python.keras._impl.keras.losses import squared_hinge

 # Auxiliary utils.
 # pylint: disable=g-bad-import-order
-from tensorflow.contrib.keras.python.keras.losses import deserialize
-from tensorflow.contrib.keras.python.keras.losses import serialize
-from tensorflow.contrib.keras.python.keras.losses import get
+from tensorflow.python.keras._impl.keras.losses import deserialize
+from tensorflow.python.keras._impl.keras.losses import serialize
+from tensorflow.python.keras._impl.keras.losses import get

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/metrics/__init__.py b/tensorflow/contrib/keras/api/keras/metrics/__init__.py
index 99496edde2d..59faf037bce 100644
--- a/tensorflow/contrib/keras/api/keras/metrics/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/metrics/__init__.py
@@ -19,28 +19,28 @@
 from __future__ import division
 from __future__ import print_function
 # Metrics functions.
-from tensorflow.contrib.keras.python.keras.metrics import binary_accuracy
-from tensorflow.contrib.keras.python.keras.metrics import binary_crossentropy
-from tensorflow.contrib.keras.python.keras.metrics import categorical_accuracy
-from tensorflow.contrib.keras.python.keras.metrics import categorical_crossentropy
-from tensorflow.contrib.keras.python.keras.metrics import cosine_proximity
-from tensorflow.contrib.keras.python.keras.metrics import hinge
-from tensorflow.contrib.keras.python.keras.metrics import kullback_leibler_divergence
-from tensorflow.contrib.keras.python.keras.metrics import mean_absolute_error
-from tensorflow.contrib.keras.python.keras.metrics import mean_absolute_percentage_error
-from tensorflow.contrib.keras.python.keras.metrics import mean_squared_error
-from tensorflow.contrib.keras.python.keras.metrics import mean_squared_logarithmic_error
-from tensorflow.contrib.keras.python.keras.metrics import poisson
-from tensorflow.contrib.keras.python.keras.metrics import sparse_categorical_crossentropy
-from tensorflow.contrib.keras.python.keras.metrics import sparse_top_k_categorical_accuracy
-from tensorflow.contrib.keras.python.keras.metrics import squared_hinge
-from tensorflow.contrib.keras.python.keras.metrics import top_k_categorical_accuracy
+from tensorflow.python.keras._impl.keras.metrics import binary_accuracy
+from tensorflow.python.keras._impl.keras.metrics import binary_crossentropy
+from tensorflow.python.keras._impl.keras.metrics import categorical_accuracy
+from tensorflow.python.keras._impl.keras.metrics import categorical_crossentropy
+from tensorflow.python.keras._impl.keras.metrics import cosine_proximity
+from tensorflow.python.keras._impl.keras.metrics import hinge
+from tensorflow.python.keras._impl.keras.metrics import kullback_leibler_divergence
+from tensorflow.python.keras._impl.keras.metrics import mean_absolute_error
+from tensorflow.python.keras._impl.keras.metrics import mean_absolute_percentage_error
+from tensorflow.python.keras._impl.keras.metrics import mean_squared_error
+from tensorflow.python.keras._impl.keras.metrics import mean_squared_logarithmic_error
+from tensorflow.python.keras._impl.keras.metrics import poisson
+from tensorflow.python.keras._impl.keras.metrics import sparse_categorical_crossentropy
+from tensorflow.python.keras._impl.keras.metrics import sparse_top_k_categorical_accuracy
+from tensorflow.python.keras._impl.keras.metrics import squared_hinge
+from tensorflow.python.keras._impl.keras.metrics import top_k_categorical_accuracy

 # Auxiliary utils.
 # pylint: disable=g-bad-import-order
-from tensorflow.contrib.keras.python.keras.metrics import deserialize
-from tensorflow.contrib.keras.python.keras.metrics import serialize
-from tensorflow.contrib.keras.python.keras.metrics import get
+from tensorflow.python.keras._impl.keras.metrics import deserialize
+from tensorflow.python.keras._impl.keras.metrics import serialize
+from tensorflow.python.keras._impl.keras.metrics import get

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/models/__init__.py b/tensorflow/contrib/keras/api/keras/models/__init__.py
index 4e5b2a1ed08..2fb4ac0960d 100644
--- a/tensorflow/contrib/keras/api/keras/models/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/models/__init__.py
@@ -18,13 +18,13 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.models import load_model
-from tensorflow.contrib.keras.python.keras.models import Model
-from tensorflow.contrib.keras.python.keras.models import model_from_config
-from tensorflow.contrib.keras.python.keras.models import model_from_json
-from tensorflow.contrib.keras.python.keras.models import model_from_yaml
-from tensorflow.contrib.keras.python.keras.models import save_model
-from tensorflow.contrib.keras.python.keras.models import Sequential
+from tensorflow.python.keras._impl.keras.models import load_model
+from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.models import model_from_config
+from tensorflow.python.keras._impl.keras.models import model_from_json
+from tensorflow.python.keras._impl.keras.models import model_from_yaml
+from tensorflow.python.keras._impl.keras.models import save_model
+from tensorflow.python.keras._impl.keras.models import Sequential

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py
index b3531d7933f..44f47bc47f4 100644
--- a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py
@@ -19,20 +19,20 @@
 from __future__ import division
 from __future__ import print_function
 # Optimizer classes.
-from tensorflow.contrib.keras.python.keras.optimizers import Adadelta
-from tensorflow.contrib.keras.python.keras.optimizers import Adagrad
-from tensorflow.contrib.keras.python.keras.optimizers import Adam
-from tensorflow.contrib.keras.python.keras.optimizers import Adamax
-from tensorflow.contrib.keras.python.keras.optimizers import Nadam
-from tensorflow.contrib.keras.python.keras.optimizers import Optimizer
-from tensorflow.contrib.keras.python.keras.optimizers import RMSprop
-from tensorflow.contrib.keras.python.keras.optimizers import SGD
+from tensorflow.python.keras._impl.keras.optimizers import Adadelta
+from tensorflow.python.keras._impl.keras.optimizers import Adagrad
+from tensorflow.python.keras._impl.keras.optimizers import Adam
+from tensorflow.python.keras._impl.keras.optimizers import Adamax
+from tensorflow.python.keras._impl.keras.optimizers import Nadam
+from tensorflow.python.keras._impl.keras.optimizers import Optimizer
+from tensorflow.python.keras._impl.keras.optimizers import RMSprop
+from tensorflow.python.keras._impl.keras.optimizers import SGD

 # Auxiliary utils.
 # pylint: disable=g-bad-import-order
-from tensorflow.contrib.keras.python.keras.optimizers import deserialize
-from tensorflow.contrib.keras.python.keras.optimizers import serialize
-from tensorflow.contrib.keras.python.keras.optimizers import get
+from tensorflow.python.keras._impl.keras.optimizers import deserialize
+from tensorflow.python.keras._impl.keras.optimizers import serialize
+from tensorflow.python.keras._impl.keras.optimizers import get

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py
index 18ce1becc29..b96e7675527 100644
--- a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py
@@ -18,20 +18,20 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.preprocessing.image import apply_transform
-from tensorflow.contrib.keras.python.keras.preprocessing.image import array_to_img
-from tensorflow.contrib.keras.python.keras.preprocessing.image import DirectoryIterator
-from tensorflow.contrib.keras.python.keras.preprocessing.image import flip_axis
-from tensorflow.contrib.keras.python.keras.preprocessing.image import ImageDataGenerator
-from tensorflow.contrib.keras.python.keras.preprocessing.image import img_to_array
-from tensorflow.contrib.keras.python.keras.preprocessing.image import Iterator
-from tensorflow.contrib.keras.python.keras.preprocessing.image import load_img
-from tensorflow.contrib.keras.python.keras.preprocessing.image import NumpyArrayIterator
-from tensorflow.contrib.keras.python.keras.preprocessing.image import random_channel_shift
-from tensorflow.contrib.keras.python.keras.preprocessing.image import random_rotation
-from tensorflow.contrib.keras.python.keras.preprocessing.image import random_shear
-from tensorflow.contrib.keras.python.keras.preprocessing.image import random_shift
-from tensorflow.contrib.keras.python.keras.preprocessing.image import random_zoom
+from tensorflow.python.keras._impl.keras.preprocessing.image import apply_transform
+from tensorflow.python.keras._impl.keras.preprocessing.image import array_to_img
+from tensorflow.python.keras._impl.keras.preprocessing.image import DirectoryIterator
+from tensorflow.python.keras._impl.keras.preprocessing.image import flip_axis
+from tensorflow.python.keras._impl.keras.preprocessing.image import ImageDataGenerator
+from tensorflow.python.keras._impl.keras.preprocessing.image import img_to_array
+from tensorflow.python.keras._impl.keras.preprocessing.image import Iterator
+from tensorflow.python.keras._impl.keras.preprocessing.image import load_img
+from tensorflow.python.keras._impl.keras.preprocessing.image import NumpyArrayIterator
+from tensorflow.python.keras._impl.keras.preprocessing.image import random_channel_shift
+from tensorflow.python.keras._impl.keras.preprocessing.image import random_rotation
+from tensorflow.python.keras._impl.keras.preprocessing.image import random_shear
+from tensorflow.python.keras._impl.keras.preprocessing.image import random_shift
+from tensorflow.python.keras._impl.keras.preprocessing.image import random_zoom

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py
index 2621e9bf53e..112f6af5e58 100644
--- a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py
@@ -18,9 +18,9 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.preprocessing.sequence import make_sampling_table
-from tensorflow.contrib.keras.python.keras.preprocessing.sequence import pad_sequences
-from tensorflow.contrib.keras.python.keras.preprocessing.sequence import skipgrams
+from tensorflow.python.keras._impl.keras.preprocessing.sequence import make_sampling_table
+from tensorflow.python.keras._impl.keras.preprocessing.sequence import pad_sequences
+from tensorflow.python.keras._impl.keras.preprocessing.sequence import skipgrams

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py
index a6b68c3ba68..5bf1a2fb21d 100644
--- a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py
@@ -18,9 +18,9 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.contrib.keras.python.keras.preprocessing.text import one_hot
-from tensorflow.contrib.keras.python.keras.preprocessing.text import text_to_word_sequence
-from tensorflow.contrib.keras.python.keras.preprocessing.text import Tokenizer
+from tensorflow.python.keras._impl.keras.preprocessing.text import one_hot
+from tensorflow.python.keras._impl.keras.preprocessing.text import text_to_word_sequence
+from tensorflow.python.keras._impl.keras.preprocessing.text import Tokenizer

 del absolute_import
 del division
diff --git a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py
index a3b0062d5c8..3e707ccab57 100644
--- a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py
@@ -19,19 +19,19 @@
 from __future__ import division
 from __future__ import print_function
 # Regularizer functions / callable classes.
-from tensorflow.contrib.keras.python.keras.regularizers import L1L2
-from tensorflow.contrib.keras.python.keras.regularizers import Regularizer
+from tensorflow.python.keras._impl.keras.regularizers import L1L2
+from tensorflow.python.keras._impl.keras.regularizers import Regularizer

 # Functional interface.
 # pylint: disable=g-bad-import-order
-from tensorflow.contrib.keras.python.keras.regularizers import l1
-from tensorflow.contrib.keras.python.keras.regularizers import l2
-from tensorflow.contrib.keras.python.keras.regularizers import l1_l2
+from tensorflow.python.keras._impl.keras.regularizers import l1
+from tensorflow.python.keras._impl.keras.regularizers import l2
+from tensorflow.python.keras._impl.keras.regularizers import l1_l2

 # Auxiliary utils.
-from tensorflow.contrib.keras.python.keras.regularizers import deserialize -from tensorflow.contrib.keras.python.keras.regularizers import serialize -from tensorflow.contrib.keras.python.keras.regularizers import get +from tensorflow.python.keras._impl.keras.regularizers import deserialize +from tensorflow.python.keras._impl.keras.regularizers import serialize +from tensorflow.python.keras._impl.keras.regularizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/utils/__init__.py b/tensorflow/contrib/keras/api/keras/utils/__init__.py index d6d70f79d5f..a7c2179fe7a 100644 --- a/tensorflow/contrib/keras/api/keras/utils/__init__.py +++ b/tensorflow/contrib/keras/api/keras/utils/__init__.py @@ -18,21 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file -from tensorflow.contrib.keras.python.keras.utils.data_utils import Sequence -from tensorflow.contrib.keras.python.keras.utils.data_utils import SequenceEnqueuer -from tensorflow.contrib.keras.python.keras.utils.generic_utils import custom_object_scope -from tensorflow.contrib.keras.python.keras.utils.generic_utils import CustomObjectScope -from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.contrib.keras.python.keras.utils.generic_utils import get_custom_objects -from tensorflow.contrib.keras.python.keras.utils.generic_utils import Progbar -from tensorflow.contrib.keras.python.keras.utils.generic_utils import serialize_keras_object -from tensorflow.contrib.keras.python.keras.utils.io_utils import HDF5Matrix -from tensorflow.contrib.keras.python.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.contrib.keras.python.keras.utils.np_utils import normalize -from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical -from tensorflow.contrib.keras.python.keras.utils.vis_utils import plot_model +from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer +from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence +from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer +from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope +from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects +from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar +from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix +from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras._impl.keras.utils.np_utils import normalize +from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical +from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py 
b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py index ba1d28c5c68..a46f859273e 100644 --- a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py +++ b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.wrappers.scikit_learn import KerasClassifier -from tensorflow.contrib.keras.python.keras.wrappers.scikit_learn import KerasRegressor +from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasClassifier +from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasRegressor del absolute_import del division diff --git a/tensorflow/contrib/keras/python/keras/__init__.py b/tensorflow/contrib/keras/python/keras/__init__.py deleted file mode 100644 index a3edb29170d..00000000000 --- a/tensorflow/contrib/keras/python/keras/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""The Keras API. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.keras.python.keras import activations -from tensorflow.contrib.keras.python.keras import applications -from tensorflow.contrib.keras.python.keras import backend -from tensorflow.contrib.keras.python.keras import callbacks -from tensorflow.contrib.keras.python.keras import constraints -from tensorflow.contrib.keras.python.keras import datasets -from tensorflow.contrib.keras.python.keras import engine -from tensorflow.contrib.keras.python.keras import initializers -from tensorflow.contrib.keras.python.keras import layers -from tensorflow.contrib.keras.python.keras import losses -from tensorflow.contrib.keras.python.keras import metrics -from tensorflow.contrib.keras.python.keras import models -from tensorflow.contrib.keras.python.keras import optimizers -from tensorflow.contrib.keras.python.keras import preprocessing -from tensorflow.contrib.keras.python.keras import regularizers -from tensorflow.contrib.keras.python.keras import utils -from tensorflow.contrib.keras.python.keras import wrappers -from tensorflow.contrib.keras.python.keras.layers import Input - -__version__ = '2.0.8-tf' diff --git a/tensorflow/contrib/keras/python/keras/layers/__init__.py b/tensorflow/contrib/keras/python/keras/layers/__init__.py deleted file mode 100644 index 9a428f31141..00000000000 --- a/tensorflow/contrib/keras/python/keras/layers/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras layers module. -""" -# pylint: disable=wildcard-import -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.keras.python.keras.engine import Input -from tensorflow.contrib.keras.python.keras.engine import InputLayer -from tensorflow.contrib.keras.python.keras.engine import InputSpec -from tensorflow.contrib.keras.python.keras.engine import Layer -from tensorflow.contrib.keras.python.keras.layers.advanced_activations import * -from tensorflow.contrib.keras.python.keras.layers.convolutional import * -from tensorflow.contrib.keras.python.keras.layers.convolutional_recurrent import * -from tensorflow.contrib.keras.python.keras.layers.core import * -from tensorflow.contrib.keras.python.keras.layers.embeddings import * -from tensorflow.contrib.keras.python.keras.layers.local import * -from tensorflow.contrib.keras.python.keras.layers.merge import * -from tensorflow.contrib.keras.python.keras.layers.noise import * -from tensorflow.contrib.keras.python.keras.layers.normalization import * -from tensorflow.contrib.keras.python.keras.layers.pooling import * -from tensorflow.contrib.keras.python.keras.layers.recurrent import * -from tensorflow.contrib.keras.python.keras.layers.serialization import deserialize -from tensorflow.contrib.keras.python.keras.layers.serialization import serialize -from tensorflow.contrib.keras.python.keras.layers.wrappers import * - diff --git a/tensorflow/contrib/keras/python/keras/utils/__init__.py b/tensorflow/contrib/keras/python/keras/utils/__init__.py deleted file mode 100644 index 3b197653f38..00000000000 --- a/tensorflow/contrib/keras/python/keras/utils/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras utilities. 
-""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.keras.python.keras.utils import conv_utils -from tensorflow.contrib.keras.python.keras.utils import data_utils -from tensorflow.contrib.keras.python.keras.utils import generic_utils -from tensorflow.contrib.keras.python.keras.utils import io_utils -from tensorflow.contrib.keras.python.keras.utils import np_utils -from tensorflow.contrib.keras.python.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file -from tensorflow.contrib.keras.python.keras.utils.data_utils import OrderedEnqueuer -from tensorflow.contrib.keras.python.keras.utils.data_utils import Sequence -from tensorflow.contrib.keras.python.keras.utils.generic_utils import custom_object_scope -from tensorflow.contrib.keras.python.keras.utils.generic_utils import CustomObjectScope -from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.contrib.keras.python.keras.utils.generic_utils import get_custom_objects -from tensorflow.contrib.keras.python.keras.utils.generic_utils import Progbar -from tensorflow.contrib.keras.python.keras.utils.generic_utils import serialize_keras_object -from tensorflow.contrib.keras.python.keras.utils.io_utils import HDF5Matrix -from tensorflow.contrib.keras.python.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.contrib.keras.python.keras.utils.np_utils import normalize -from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical -from tensorflow.contrib.keras.python.keras.utils.vis_utils import plot_model - - -# Globally-importable utils. diff --git a/tensorflow/contrib/keras/python/keras/utils/io_utils_test.py b/tensorflow/contrib/keras/python/keras/utils/io_utils_test.py deleted file mode 100644 index f6820ee0394..00000000000 --- a/tensorflow/contrib/keras/python/keras/utils/io_utils_test.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for io_utils.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import shutil - -import numpy as np - -from tensorflow.contrib.keras.python import keras -from tensorflow.python.platform import test - -try: - import h5py # pylint:disable=g-import-not-at-top -except ImportError: - h5py = None - - -def create_dataset(h5_path='test.h5'): - x = np.random.randn(200, 10).astype('float32') - y = np.random.randint(0, 2, size=(200, 1)) - f = h5py.File(h5_path, 'w') - # Creating dataset to store features - x_dset = f.create_dataset('my_data', (200, 10), dtype='f') - x_dset[:] = x - # Creating dataset to store labels - y_dset = f.create_dataset('my_labels', (200, 1), dtype='i') - y_dset[:] = y - f.close() - - -class TestIOUtils(test.TestCase): - - def test_HDF5Matrix(self): - if h5py is None: - return - - temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir) - - h5_path = os.path.join(temp_dir, 'test.h5') - create_dataset(h5_path) - - with self.test_session(): - # Instantiating HDF5Matrix for the training set, - # which is a slice of the first 150 elements - x_train = keras.utils.io_utils.HDF5Matrix( - h5_path, 'my_data', start=0, end=150) - y_train = keras.utils.io_utils.HDF5Matrix( - h5_path, 'my_labels', start=0, end=150) - - # Likewise for the test set - x_test = keras.utils.io_utils.HDF5Matrix( - h5_path, 'my_data', start=150, end=200) - y_test = keras.utils.io_utils.HDF5Matrix( - h5_path, 'my_labels', start=150, end=200) - - # HDF5Matrix behave more or less like Numpy matrices - # with regard to indexing - self.assertEqual(y_train.shape, (150, 1)) - # But they don't support negative indices, so don't try print(x_train[-1]) - - self.assertEqual(y_train.dtype, np.dtype('i')) - self.assertEqual(y_train.ndim, 2) - self.assertEqual(y_train.size, 150) - - model = keras.models.Sequential() - model.add(keras.layers.Dense(64, input_shape=(10,), activation='relu')) - model.add(keras.layers.Dense(1, activation='sigmoid')) - model.compile(loss='binary_crossentropy', optimizer='sgd') - - # Note: you have to use shuffle='batch' or False with HDF5Matrix - model.fit(x_train, y_train, batch_size=32, shuffle='batch', verbose=False) - # test that evaluation and prediction - # don't crash and return reasonable results - out_pred = model.predict(x_test, batch_size=32, verbose=False) - out_eval = model.evaluate(x_test, y_test, batch_size=32, verbose=False) - - self.assertEqual(out_pred.shape, (50, 1)) - self.assertEqual(out_eval.shape, ()) - self.assertGreater(out_eval, 0) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index db3be9a991b..d35b5556fc8 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -411,33 +411,6 @@ py_test( ], ) -py_test( - name = "dnn_linear_combined_benchmark_test", - size = "medium", - srcs = ["python/learn/estimators/dnn_linear_combined_benchmark_test.py"], - srcs_version = "PY2AND3", - tags = [ - "guitar", - "local", - "manual", - "notap", - ], - visibility = [ - "//learning/brain/google/guitar:__subpackages__", - "//tensorflow:__subpackages__", - ], - deps = [ - ":learn", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/learn/python/learn/datasets", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", -
"//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - ], -) - py_test( name = "kmeans_test", size = "medium", @@ -459,32 +432,6 @@ py_test( ], ) -py_test( - name = "dnn_benchmark_test", - size = "medium", - srcs = ["python/learn/estimators/dnn_benchmark_test.py"], - srcs_version = "PY2AND3", - tags = [ - "guitar", - "local", - "manual", - "notap", - ], - visibility = [ - "//learning/brain/google/guitar:__subpackages__", - "//tensorflow:__subpackages__", - ], - deps = [ - ":learn", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//third_party/py/numpy", - ], -) - py_test( name = "dynamic_rnn_estimator_test", size = "medium", diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_benchmark_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_benchmark_test.py deleted file mode 100644 index 86b3eee6ad1..00000000000 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_benchmark_test.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Regression test for DNNEstimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import numpy as np -from tensorflow.contrib.layers.python.layers import feature_column -from tensorflow.contrib.learn.python.learn.estimators import dnn -from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils -from tensorflow.contrib.learn.python.learn.estimators import run_config -from tensorflow.contrib.learn.python.learn.estimators import test_data -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.platform import test -from tensorflow.python.training import input as input_lib - - -_METRIC_KEYS = { - 'accuracy', - 'auc', - 'accuracy/threshold_0.500000_mean', - 'loss', - 'precision/positive_threshold_0.500000_mean', - 'recall/positive_threshold_0.500000_mean', -} - - -class DNNClassifierBenchmark(test.Benchmark): - - def _report_metrics(self, metrics): - self.report_benchmark( - iters=metrics['global_step'], - extras={k: v - for k, v in metrics.items() if k in _METRIC_KEYS}) - - def _report_predictions(self, - benchmark_name_override, - classifier, - input_fn, - iters, - n_examples, - n_classes, - expected_probabilities=None, - expected_classes=None): - probabilities = classifier.predict_proba( - input_fn=input_fn, as_iterable=False) - if expected_probabilities is not None: - np.testing.assert_allclose( - expected_probabilities, tuple(probabilities), atol=0.2) - - classes = classifier.predict(input_fn=input_fn, as_iterable=False) - if expected_classes is not None: - np.testing.assert_array_equal(expected_classes, classes) - - self.report_benchmark( - iters=iters, - extras={ - 'inference.example%d_class%d_prob' % (i, j): probabilities[i][j] - for j in range(n_classes) for i in range(n_examples) - }.update({ - 'inference.example%d_class' % i: classes[i] - for i in range(n_examples) - }), - name=benchmark_name_override) - - def benchmarkLogisticMatrixData(self): - classifier = dnn.DNNClassifier( - feature_columns=(feature_column.real_valued_column( - 'feature', dimension=4),), - hidden_units=(3, 3), - config=run_config.RunConfig(tf_random_seed=1)) - input_fn = test_data.iris_input_logistic_fn - steps = 400 - metrics = classifier.fit(input_fn=input_fn, steps=steps).evaluate( - input_fn=input_fn, steps=1) - estimator_test_utils.assert_in_range(steps, steps + 5, 'global_step', - metrics) - estimator_test_utils.assert_in_range(0.9, 1.0, 'accuracy', metrics) - estimator_test_utils.assert_in_range(0.0, 0.3, 'loss', metrics) - - self._report_metrics(metrics) - - def benchmarkLogisticMatrixDataLabels1D(self): - - def _input_fn(): - iris = test_data.prepare_iris_data_for_logistic_regression() - return { - 'feature': constant_op.constant( - iris.data, dtype=dtypes.float32) - }, constant_op.constant( - iris.target, shape=(100,), dtype=dtypes.int32) - - classifier = dnn.DNNClassifier( - feature_columns=(feature_column.real_valued_column( - 'feature', dimension=4),), - hidden_units=(3, 3), - config=run_config.RunConfig(tf_random_seed=1)) - steps = 1000 - metrics = classifier.fit(input_fn=_input_fn, steps=steps).evaluate( - input_fn=_input_fn, steps=1) - estimator_test_utils.assert_in_range(steps, steps + 5, 'global_step', - metrics) - estimator_test_utils.assert_in_range(0.9, 1.0, 'accuracy', 
metrics) - - self._report_metrics(metrics) - - def benchmarkLogisticNpMatrixData(self): - classifier = dnn.DNNClassifier( - feature_columns=(feature_column.real_valued_column( - '', dimension=4),), - hidden_units=(3, 3), - config=run_config.RunConfig(tf_random_seed=1)) - iris = test_data.prepare_iris_data_for_logistic_regression() - train_x = iris.data - train_y = iris.target - steps = 100 - metrics = classifier.fit(x=train_x, y=train_y, steps=steps).evaluate( - x=train_x, y=train_y, steps=1) - estimator_test_utils.assert_in_range(steps, steps + 5, 'global_step', - metrics) - estimator_test_utils.assert_in_range(0.8, 1.0, 'accuracy', metrics) - - self._report_metrics(metrics) - - def benchmarkLogisticTensorData(self): - - def _input_fn(num_epochs=None): - features = { - 'age': - input_lib.limit_epochs( - constant_op.constant(((.8,), (0.2,), (.1,))), - num_epochs=num_epochs), - 'language': - sparse_tensor.SparseTensor( - values=input_lib.limit_epochs( - ('en', 'fr', 'zh'), num_epochs=num_epochs), - indices=((0, 0), (0, 1), (2, 0)), - dense_shape=(3, 2)) - } - return features, constant_op.constant( - ((1,), (0,), (0,)), dtype=dtypes.int32) - - lang_column = feature_column.sparse_column_with_hash_bucket( - 'language', hash_bucket_size=20) - classifier = dnn.DNNClassifier( - feature_columns=(feature_column.embedding_column( - lang_column, dimension=1), - feature_column.real_valued_column('age')), - hidden_units=(3, 3), - config=run_config.RunConfig(tf_random_seed=1)) - steps = 100 - metrics = classifier.fit(input_fn=_input_fn, steps=steps).evaluate( - input_fn=_input_fn, steps=1) - estimator_test_utils.assert_in_range(steps, steps + 5, 'global_step', - metrics) - estimator_test_utils.assert_in_range(0.9, 1.0, 'accuracy', metrics) - estimator_test_utils.assert_in_range(0.0, 0.3, 'loss', metrics) - - self._report_metrics(metrics) - self._report_predictions( - classifier=classifier, - input_fn=functools.partial(_input_fn, num_epochs=1), - iters=metrics['global_step'], - n_examples=3, - n_classes=2, - expected_classes=(1, 0, 0), - benchmark_name_override=( - 'DNNClassifierBenchmark.benchmarkLogisticTensorData_predictions')) - - def benchmarkLogisticFloatLabel(self): - - def _input_fn(num_epochs=None): - features = { - 'age': - input_lib.limit_epochs( - constant_op.constant(((50,), (20,), (10,))), - num_epochs=num_epochs), - 'language': - sparse_tensor.SparseTensor( - values=input_lib.limit_epochs( - ('en', 'fr', 'zh'), num_epochs=num_epochs), - indices=((0, 0), (0, 1), (2, 0)), - dense_shape=(3, 2)) - } - return features, constant_op.constant( - ((0.8,), (0.,), (0.2,)), dtype=dtypes.float32) - - lang_column = feature_column.sparse_column_with_hash_bucket( - 'language', hash_bucket_size=20) - n_classes = 2 - classifier = dnn.DNNClassifier( - n_classes=n_classes, - feature_columns=(feature_column.embedding_column( - lang_column, dimension=1), - feature_column.real_valued_column('age')), - hidden_units=(3, 3), - config=run_config.RunConfig(tf_random_seed=1)) - steps = 1000 - metrics = classifier.fit(input_fn=_input_fn, steps=steps).evaluate( - input_fn=_input_fn, steps=1) - estimator_test_utils.assert_in_range(steps, steps + 5, 'global_step', - metrics) - - # Prediction probabilities mirror the labels column, which proves that the - # classifier learns from float input. 
- self._report_metrics(metrics) - self._report_predictions( - classifier=classifier, - input_fn=functools.partial(_input_fn, num_epochs=1), - iters=metrics['global_step'], - n_examples=3, - n_classes=n_classes, - expected_probabilities=((0.2, 0.8), (1., 0.), (0.8, 0.2)), - expected_classes=(1, 0, 0), - benchmark_name_override=( - 'DNNClassifierBenchmark.benchmarkLogisticFloatLabel_predictions')) - - def benchmarkMultiClassMatrixData(self): - """Tests multi-class classification using matrix data as input.""" - classifier = dnn.DNNClassifier( - n_classes=3, - feature_columns=(feature_column.real_valued_column( - 'feature', dimension=4),), - hidden_units=(3, 3), - config=run_config.RunConfig(tf_random_seed=1)) - - input_fn = test_data.iris_input_multiclass_fn - steps = 500 - metrics = classifier.fit(input_fn=input_fn, steps=steps).evaluate( - input_fn=input_fn, steps=1) - estimator_test_utils.assert_in_range(steps, steps + 5, 'global_step', - metrics) - estimator_test_utils.assert_in_range(0.9, 1.0, 'accuracy', metrics) - estimator_test_utils.assert_in_range(0.0, 0.4, 'loss', metrics) - - self._report_metrics(metrics) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_benchmark_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_benchmark_test.py deleted file mode 100644 index 98b7c7e95c5..00000000000 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_benchmark_test.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Regression test for DNNLinearCombinedEstimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import json -import tempfile -from tensorflow.contrib.layers.python.layers import feature_column -from tensorflow.contrib.learn.python.learn.datasets import base -from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined -from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils -from tensorflow.contrib.learn.python.learn.estimators import run_config -from tensorflow.contrib.learn.python.learn.estimators import test_data -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test -from tensorflow.python.training import adagrad -from tensorflow.python.training import ftrl -from tensorflow.python.training import server_lib - - -# Desired training steps, reported in benchmark. Actual steps might be slightly -# more than this since supervisor training runs for a non-deterministic number of -# steps.
-_ITERS = 100 - -_METRIC_KEYS = { - 'accuracy', - 'auc', - 'accuracy/threshold_0.500000_mean', - 'loss', - 'precision/positive_threshold_0.500000_mean', - 'recall/positive_threshold_0.500000_mean', -} - - -class DNNLinearCombinedClassifierBenchmark(test.Benchmark): - - def _assertSingleClassMetrics(self, metrics): - estimator_test_utils.assert_in_range(0.9, 1.0, 'auc', metrics) - estimator_test_utils.assert_in_range(0.9, 1.0, - 'accuracy/threshold_0.500000_mean', - metrics) - estimator_test_utils.assert_in_range( - 0.9, 1.0, 'precision/positive_threshold_0.500000_mean', metrics) - estimator_test_utils.assert_in_range( - 0.9, 1.0, 'recall/positive_threshold_0.500000_mean', metrics) - self._assertCommonMetrics(metrics) - - def _assertCommonMetrics(self, metrics): - estimator_test_utils.assert_in_range(_ITERS, _ITERS + 5, 'global_step', - metrics) - estimator_test_utils.assert_in_range(0.9, 1.0, 'accuracy', metrics) - estimator_test_utils.assert_in_range(0.0, 0.2, 'loss', metrics) - self.report_benchmark( - iters=metrics['global_step'], - extras={k: v - for k, v in metrics.items() if k in _METRIC_KEYS}) - - def benchmarkMatrixData(self): - iris = test_data.prepare_iris_data_for_logistic_regression() - cont_feature = feature_column.real_valued_column('feature', dimension=4) - bucketized_feature = feature_column.bucketized_column( - cont_feature, test_data.get_quantile_based_buckets(iris.data, 10)) - - classifier = dnn_linear_combined.DNNLinearCombinedClassifier( - model_dir=tempfile.mkdtemp(), - linear_feature_columns=(bucketized_feature,), - dnn_feature_columns=(cont_feature,), - dnn_hidden_units=(3, 3)) - - input_fn = test_data.iris_input_logistic_fn - metrics = classifier.fit(input_fn=input_fn, steps=_ITERS).evaluate( - input_fn=input_fn, steps=100) - self._assertSingleClassMetrics(metrics) - - def benchmarkTensorData(self): - - def _input_fn(): - iris = test_data.prepare_iris_data_for_logistic_regression() - features = {} - for i in range(4): - # The following shows how to provide the Tensor data for - # RealValuedColumns. - features.update({ - str(i): - array_ops.reshape( - constant_op.constant( - iris.data[:, i], dtype=dtypes.float32), (-1, 1)) - }) - # The following shows how to provide the SparseTensor data for - # a SparseColumn. 
- features['dummy_sparse_column'] = sparse_tensor.SparseTensor( - values=('en', 'fr', 'zh'), - indices=((0, 0), (0, 1), (60, 0)), - dense_shape=(len(iris.target), 2)) - labels = array_ops.reshape( - constant_op.constant( - iris.target, dtype=dtypes.int32), (-1, 1)) - return features, labels - - iris = test_data.prepare_iris_data_for_logistic_regression() - cont_features = [ - feature_column.real_valued_column(str(i)) for i in range(4) - ] - linear_features = [ - feature_column.bucketized_column( - cont_features[i], - test_data.get_quantile_based_buckets(iris.data[:, i], 10)) - for i in range(4) - ] - linear_features.append( - feature_column.sparse_column_with_hash_bucket( - 'dummy_sparse_column', hash_bucket_size=100)) - - classifier = dnn_linear_combined.DNNLinearCombinedClassifier( - model_dir=tempfile.mkdtemp(), - linear_feature_columns=linear_features, - dnn_feature_columns=cont_features, - dnn_hidden_units=(3, 3)) - - metrics = classifier.fit(input_fn=_input_fn, steps=_ITERS).evaluate( - input_fn=_input_fn, steps=100) - self._assertSingleClassMetrics(metrics) - - def benchmarkCustomOptimizer(self): - iris = test_data.prepare_iris_data_for_logistic_regression() - cont_feature = feature_column.real_valued_column('feature', dimension=4) - bucketized_feature = feature_column.bucketized_column( - cont_feature, test_data.get_quantile_based_buckets(iris.data, 10)) - - classifier = dnn_linear_combined.DNNLinearCombinedClassifier( - model_dir=tempfile.mkdtemp(), - linear_feature_columns=(bucketized_feature,), - linear_optimizer=ftrl.FtrlOptimizer(learning_rate=0.1), - dnn_feature_columns=(cont_feature,), - dnn_hidden_units=(3, 3), - dnn_optimizer=adagrad.AdagradOptimizer(learning_rate=0.1)) - - input_fn = test_data.iris_input_logistic_fn - metrics = classifier.fit(input_fn=input_fn, steps=_ITERS).evaluate( - input_fn=input_fn, steps=100) - self._assertSingleClassMetrics(metrics) - - def benchmarkMultiClass(self): - iris = base.load_iris() - cont_feature = feature_column.real_valued_column('feature', dimension=4) - bucketized_feature = feature_column.bucketized_column( - cont_feature, test_data.get_quantile_based_buckets(iris.data, 10)) - - classifier = dnn_linear_combined.DNNLinearCombinedClassifier( - n_classes=3, - linear_feature_columns=(bucketized_feature,), - dnn_feature_columns=(cont_feature,), - dnn_hidden_units=(3, 3)) - - input_fn = test_data.iris_input_multiclass_fn - metrics = classifier.fit(input_fn=input_fn, steps=_ITERS).evaluate( - input_fn=input_fn, steps=100) - self._assertCommonMetrics(metrics) - - def benchmarkPartitionedVariables(self): - - def _input_fn(): - features = { - 'language': - sparse_tensor.SparseTensor( - values=('en', 'fr', 'zh'), - indices=((0, 0), (0, 1), (2, 0)), - dense_shape=(3, 2)) - } - labels = constant_op.constant(((1,), (0,), (0,))) - return features, labels - - # The given hash_bucket_size results in variables larger than the - # default min_slice_size attribute, so the variables are partitioned. 
- sparse_feature = feature_column.sparse_column_with_hash_bucket( - 'language', hash_bucket_size=2e7) - embedding_feature = feature_column.embedding_column( - sparse_feature, dimension=1) - - tf_config = { - 'cluster': { - run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] - } - } - with test.mock.patch.dict('os.environ', - {'TF_CONFIG': json.dumps(tf_config)}): - config = run_config.RunConfig() - # Because we did not start a distributed cluster, we need to pass an - # empty ClusterSpec, otherwise the device_setter will look for - # distributed jobs, such as "/job:ps" which are not present. - config._cluster_spec = server_lib.ClusterSpec({}) - - classifier = dnn_linear_combined.DNNLinearCombinedClassifier( - linear_feature_columns=(sparse_feature,), - dnn_feature_columns=(embedding_feature,), - dnn_hidden_units=(3, 3), - config=config) - - metrics = classifier.fit(input_fn=_input_fn, steps=_ITERS).evaluate( - input_fn=_input_fn, steps=100) - self._assertCommonMetrics(metrics) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 225d8796785..861db1f89ef 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -1070,8 +1070,8 @@ class _MultiClassHead(_SingleHead): labels_tensor = _to_labels_tensor(labels, self._label_name) _check_no_sparse_tensor(labels_tensor) if self._label_keys: - table = lookup_ops.index_table_from_tensor(self._label_keys, - name="label_id_lookup") + table = lookup_ops.index_table_from_tensor( + self._label_keys, name="label_id_lookup") return { "labels": labels_tensor, "label_ids": table.lookup(labels_tensor), diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py index 0f09b111bd8..896b668d4e2 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py @@ -178,7 +178,7 @@ def select_last_activations(activations, sequence_lengths): """Selects the nth set of activations for each n in `sequence_length`. Returns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not - `None`, then `output[i, :] = activations[i, sequence_length[i], :]`. If + `None`, then `output[i, :] = activations[i, sequence_length[i] - 1, :]`. If `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`.
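To make the corrected indexing above concrete, here is a minimal NumPy sketch of the documented contract; this is an illustration only, not the implementation in rnn_common.py:

import numpy as np

batch_size, padded_length, k = 2, 5, 3
activations = np.random.rand(batch_size, padded_length, k)
sequence_lengths = np.array([4, 2])  # number of valid steps per example

# For each example i, select activations[i, sequence_lengths[i] - 1, :],
# i.e. the activation at the last valid time step rather than the padding.
last = activations[np.arange(batch_size), sequence_lengths - 1, :]
assert last.shape == (batch_size, k)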
Args: diff --git a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py index 9aa1e056284..6253f96315b 100644 --- a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py +++ b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py @@ -163,7 +163,7 @@ class ResamplerTest(test.TestCase): data_channels = 3 warp_width = 2 warp_height = 6 - batch_size = 10 + batch_size = 3 warp = _make_warp(batch_size, warp_height, warp_width, dtype.as_numpy_dtype) data_shape = (batch_size, data_height, data_width, data_channels) diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index 596c4f351ce..ebb7a218562 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -234,7 +234,7 @@ cc_library( cc_test( name = "session_bundle_test", - size = "small", + size = "medium", srcs = ["session_bundle_test.cc"], data = [":session_bundle_half_plus_two"], # Link in all registered kernels. diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.cc b/tensorflow/contrib/session_bundle/session_bundle_test.cc index eb36d79e0f4..6d997bac9ee 100644 --- a/tensorflow/contrib/session_bundle/session_bundle_test.cc +++ b/tensorflow/contrib/session_bundle/session_bundle_test.cc @@ -171,7 +171,8 @@ void BasicTest(const string& export_path) { // SessionBundles. Concurrent with adding this test, we had a leak where the // TensorFlow Session was not being closed, which leaked memory. // TODO(b/31711147): Increase the SessionBundle ResourceLeakTest iterations and -// move outside of the test suite. +// move outside of the test suite; decrease test size back to small at the same +// time. TEST(LoadSessionBundleFromPath, ResourceLeakTest) { const string export_path = test_util::TestSrcDirPath(kExportPath); for (int i = 0; i < 100; i++) { diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index bc305022642..527deab86a6 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -22,10 +22,12 @@ py_test( srcs_version = "PY2AND3", deps = [ ":summary_ops", + "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/python:training", - "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", "//tensorflow/python/eager:test", ], ) @@ -38,6 +40,7 @@ py_library( deps = [ ":gen_summary_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:summary_op_util", diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index 05e627adf1c..ceaf83b70a7 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -68,7 +68,8 @@ def never_record_summaries(): def create_summary_file_writer(logdir, max_queue=None, flush_secs=None, - filename_suffix=None): + filename_suffix=None, + name=None): """Creates a summary file writer in the current context.""" if max_queue is None: max_queue = constant_op.constant(10) @@ -76,7 +77,7 @@ def create_summary_file_writer(logdir, flush_secs = constant_op.constant(120) if filename_suffix is None: filename_suffix = constant_op.constant("") - resource = gen_summary_ops.summary_writer() + resource = gen_summary_ops.summary_writer(shared_name=name) 
gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue, flush_secs, filename_suffix) context.context().summary_writer_resource = resource @@ -84,76 +85,87 @@ def create_summary_file_writer(logdir, def _nothing(): """Convenient else branch for when summaries do not record.""" - return + return False + + +def summary_writer_function(name, tensor, function, family=None): + """Helper function to write summaries. + + Args: + name: name of the summary + tensor: main tensor to form the summary + function: function taking a tag and a scope which writes the summary + family: optional, the summary's family + + Returns: + The result of writing the summary. + """ + def record(): + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + function(tag, scope) + return True + + return control_flow_ops.cond(should_record_summaries(), record, _nothing) def generic(name, tensor, metadata, family=None): """Writes a tensor summary if possible.""" - def record(): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_summary(context.context().summary_writer_resource, - training_util.get_global_step(), tensor, - tag, metadata, name=scope) - return control_flow_ops.cond(should_record_summaries(), record, _nothing) + def function(tag, scope): + gen_summary_ops.write_summary(context.context().summary_writer_resource, + training_util.get_global_step(), tensor, + tag, metadata, name=scope) + return summary_writer_function(name, tensor, function, family=family) def scalar(name, tensor, family=None): """Writes a scalar summary if possible.""" - def record(): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_scalar_summary( - context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, name=scope) + def function(tag, scope): + gen_summary_ops.write_scalar_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), tag, tensor, name=scope) - return control_flow_ops.cond(should_record_summaries(), record, _nothing) + return summary_writer_function(name, tensor, function, family=family) def histogram(name, tensor, family=None): """Writes a histogram summary if possible.""" - def record(): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_histogram_summary( - context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, name=scope) + def function(tag, scope): + gen_summary_ops.write_histogram_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), tag, tensor, name=scope) - return control_flow_ops.cond(should_record_summaries(), record, _nothing) + return summary_writer_function(name, tensor, function, family=family) def image(name, tensor, bad_color=None, max_images=3, family=None): """Writes an image summary if possible.""" - def record(): + def function(tag, scope): if bad_color is None: bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_image_summary( - context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, bad_color_, max_images, - name=scope) + gen_summary_ops.write_image_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), tag, tensor, bad_color_, max_images, + name=scope) 
- return control_flow_ops.cond(should_record_summaries(), record, _nothing) + return summary_writer_function(name, tensor, function, family=family) def audio(name, tensor, sample_rate, max_outputs, family=None): """Writes an audio summary if possible.""" - def record(): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_audio_summary( - context.context().summary_writer_resource, - training_util.get_global_step(), - tag, - tensor, - sample_rate=sample_rate, - max_outputs=max_outputs, - name=scope) + def function(tag, scope): + gen_summary_ops.write_audio_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), + tag, + tensor, + sample_rate=sample_rate, + max_outputs=max_outputs, + name=scope) - return control_flow_ops.cond(should_record_summaries(), record, _nothing) + return summary_writer_function(name, tensor, function, family=family) diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 56c1a16f7f0..4b1f60ce4eb 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -17,11 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import tempfile from tensorflow.contrib.summary import summary_ops +from tensorflow.core.util import event_pb2 +from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import test_util +from tensorflow.python.lib.io import tf_record from tensorflow.python.platform import gfile from tensorflow.python.training import training_util @@ -36,7 +40,7 @@ class TargetTest(test_util.TensorFlowTestCase): def testSummaryOps(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - summary_ops.create_summary_file_writer(logdir, max_queue=0) + summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0') summary_ops.always_record_summaries() summary_ops.generic('tensor', 1, '') summary_ops.scalar('scalar', 2.0) @@ -47,6 +51,27 @@ class TargetTest(test_util.TensorFlowTestCase): # test here that we're calling them correctly. 
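As context for the summary_ops.py refactor above: the user-facing behavior is unchanged, and the updated tests below exercise the new `name` argument. A short usage sketch mirroring those tests (assuming an eager context, as the tests use; illustrative only):

import tempfile

from tensorflow.contrib.summary import summary_ops
from tensorflow.python.training import training_util

# Summaries are written against the current global step, so create one first.
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
# The new `name` argument becomes the shared_name of the writer resource.
summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0')
summary_ops.always_record_summaries()
summary_ops.scalar('scalar', 2.0)  # the cond now returns True when recording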
self.assertTrue(gfile.Exists(logdir)) + def testDefunSummarys(self): + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1') + summary_ops.always_record_summaries() + + @function.defun + def write(): + summary_ops.scalar('scalar', 2.0) + + write() + + self.assertTrue(gfile.Exists(logdir)) + files = gfile.ListDirectory(logdir) + self.assertEqual(len(files), 1) + records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) + self.assertEqual(len(records), 2) + event = event_pb2.Event() + event.ParseFromString(records[1]) + self.assertEqual(event.summary.value[0].simple_value, 2.0) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc index ccc412600c7..e5d1beae7f9 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc @@ -96,7 +96,12 @@ void SplitCollectionOperator::AddExample( } bool SplitCollectionOperator::IsInitialized(int32 node_id) const { - return stats_.at(node_id)->IsInitialized(); + auto it = stats_.find(node_id); + if (it == stats_.end()) { + LOG(WARNING) << "IsInitialized called with unknown node_id = " << node_id; + return false; + } + return it->second->IsInitialized(); } void SplitCollectionOperator::CreateAndInitializeCandidateWithExample( diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace.py b/tensorflow/contrib/tensorboard/plugins/trace/trace.py index 57f95dfce72..07e5316b8b3 100644 --- a/tensorflow/contrib/tensorboard/plugins/trace/trace.py +++ b/tensorflow/contrib/tensorboard/plugins/trace/trace.py @@ -38,7 +38,7 @@ TOKENS = LEFT_TOKENS + RIGHT_TOKENS def store_trace_info(output_file_path, - graph=ops.get_default_graph(), + graph=None, ignore_regex_fpaths=None): """Collects and stores trace information for a TensorFlow model. @@ -51,6 +51,8 @@ def store_trace_info(output_file_path, in this list will be ignored. Defaults to patterns that match the core tensorflow python library. """ + graph = graph or ops.get_default_graph() + if not ignore_regex_fpaths: ignore_regex_fpaths = TF_LIB_REGEX_FPATHS diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto index 6911b649a04..840a43913ba 100644 --- a/tensorflow/contrib/tpu/profiler/op_profile.proto +++ b/tensorflow/contrib/tpu/profiler/op_profile.proto @@ -32,6 +32,18 @@ message Node { string expression = 2; // %multiply = [shape]multiply(operand1, operand2) string provenance = 3; // Typically the TensorFlow operation name. string category = 4; + // Describes the physical memory layout of the instruction's primary input. + // e.g. for a convolution, this analyzes the image and ignores the kernel. + LayoutAnalysis layout = 5; + message LayoutAnalysis { + // The physical data layout, from most-minor to most-major dimensions. + repeated Dimension dimensions = 1; + message Dimension { + int32 size = 1; // Size of the data in this dimension. + int32 alignment = 2; // Data must be padded to a multiple of alignment. + string semantics = 3; // What the dimension represents, e.g. "spatial". 
+ } + } } } diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 9db2ed830f4..93199283075 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2380,6 +2380,7 @@ tf_cc_tests( "util/semver_test.cc", "util/sparse/sparse_tensor_test.cc", "util/stat_summarizer_test.cc", + "util/tensor_format_test.cc", "util/tensor_slice_reader_test.cc", "util/tensor_slice_set_test.cc", "util/tensor_slice_util_test.cc", diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc index 2a974d18402..363d3a0c9d3 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/util.h" @@ -54,7 +53,6 @@ SimpleGraphExecutionState::SimpleGraphExecutionState( : stateful_placements_(options.stateful_placements), device_set_(options.device_set), session_options_(options.session_options), - costs_(true /*is_global*/), flib_def_(new FunctionLibraryDefinition(OpRegistry::Global(), graph_def->library())), graph_(nullptr) { @@ -258,19 +256,11 @@ Status SimpleGraphExecutionState::InitBaseGraph( // Save stateful placements before placing. RestoreStatefulNodes(new_graph.get()); - CostModel costs(true /*is_global*/); - { - mutex_lock l(mu_); - costs_.InitFromGraph(*new_graph); - costs.MergeFromGlobal(costs_); - } - GraphOptimizationPassOptions optimization_options; optimization_options.session_options = session_options_; optimization_options.graph = &new_graph; optimization_options.flib_def = flib_def_.get(); optimization_options.device_set = device_set_; - optimization_options.cost_model = &costs; TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::PRE_PLACEMENT, optimization_options)); @@ -420,14 +410,11 @@ Status SimpleGraphExecutionState::BuildGraph( new FunctionLibraryDefinition(*flib_def_)); // TODO(andydavis): Clarify optimization pass requirements around CostModel. - CostModel costs(true /*is_global*/); - costs.MergeFromGlobal(costs_); GraphOptimizationPassOptions optimization_options; optimization_options.session_options = session_options_; optimization_options.graph = &ng; optimization_options.flib_def = flib.get(); optimization_options.device_set = device_set_; - optimization_options.cost_model = &costs; TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options)); diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/simple_graph_execution_state.h index c7f34a42d61..53eef8a07d5 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.h +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.h @@ -25,19 +25,14 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/graph/costmodel.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { struct SessionOptions; -class StepStats; -class Timeline; namespace subgraph { struct RewriteGraphMetadata; @@ -167,7 +162,6 @@ class SimpleGraphExecutionState { // Returns the map of stateful placements as a map of // node name to placement string. std::unordered_map GetStatefulPlacements() const { - mutex_lock l(mu_); return stateful_placements_; } @@ -193,9 +187,6 @@ class SimpleGraphExecutionState { const DeviceSet* device_set_; // Not owned const SessionOptions* session_options_; // Not owned - mutable mutex mu_; - CostModel costs_ GUARDED_BY(mu_); - // Map from name to Node for the full graph in placed_. NodeNameToCostIdMap node_name_to_cost_id_map_; diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 0e3ea2ddfb5..ab21f47282e 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -206,15 +206,28 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) { Status FusedConvBiasActivationShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(Conv2DShape(c)); - ShapeHandle bias_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &bias_shape)); - DimensionHandle bias_dim = c->Dim(bias_shape, 0); + string data_format_str, filter_format_str; + TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); + TF_RETURN_IF_ERROR(c->GetAttr("filter_format", &filter_format_str)); + TensorFormat data_format; + FormatFromString(data_format_str, &data_format); + FilterTensorFormat filter_format; + FilterFormatFromString(filter_format_str, &filter_format); + + constexpr int num_spatial_dims = 2; + const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); ShapeHandle filter_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape)); - DimensionHandle output_depth_dim = c->Dim(filter_shape, 3); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); + DimensionHandle output_depth_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'O')); int64 output_depth_dim_val = c->Value(output_depth_dim); + + ShapeHandle bias_shape; + // Bias should be a 1-D tensor. + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape)); + DimensionHandle bias_dim = c->Dim(bias_shape, 0); int64 bias_dim_val = c->Value(bias_dim); if (output_depth_dim_val != bias_dim_val) { @@ -223,6 +236,14 @@ Status FusedConvBiasActivationShape(shape_inference::InferenceContext* c) { ") and bias dimension (", bias_dim_val, ") do not match."); } + // Check side input shape matches the output shape. 
+ ShapeHandle side_input_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &side_input_shape)); + if (c->Rank(side_input_shape) > 1) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused)); + } + return Status::OK(); } @@ -323,24 +344,38 @@ Status ShapeFromDimensions(DimensionHandle batch_dim, } Status Conv2DShape(shape_inference::InferenceContext* c) { - string data_format_str; - Status s = c->GetAttr("data_format", &data_format_str); - if (!s.ok()) { + string data_format_str, filter_format_str; + if (!c->GetAttr("data_format", &data_format_str).ok()) { data_format_str = "NHWC"; } + if (!c->GetAttr("filter_format", &filter_format_str).ok()) { + filter_format_str = "HWIO"; + } TensorFormat data_format; if (!FormatFromString(data_format_str, &data_format)) { return errors::InvalidArgument("Invalid data format string: ", data_format_str); } + FilterTensorFormat filter_format; + if (!FilterFormatFromString(filter_format_str, &filter_format)) { + return errors::InvalidArgument("Invalid filter format string: ", + filter_format_str); + } + + constexpr int num_spatial_dims = 2; + const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); + ShapeHandle conv_input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape)); + TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape( + data_format, conv_input_shape, "conv_input", c)); - const int rank = GetTensorDimsFromSpatialDims(2, data_format); - ShapeHandle input_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C). ShapeHandle filter_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); + TF_RETURN_IF_ERROR( + CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c)); + std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); @@ -352,38 +387,33 @@ Status Conv2DShape(shape_inference::InferenceContext* c) { strides.size()); } - int32 stride_rows, stride_cols; - if (data_format == FORMAT_NCHW || data_format == FORMAT_NCHW_VECT_C) { - stride_rows = strides[2]; - stride_cols = strides[3]; - } else { - stride_rows = strides[1]; - stride_cols = strides[2]; - } + const int32 stride_rows = GetTensorDim(strides, data_format, 'H'); + const int32 stride_cols = GetTensorDim(strides, data_format, 'W'); DimensionHandle batch_size_dim; DimensionHandle input_depth_dim; gtl::InlinedVector input_spatial_dims(2); - TF_RETURN_IF_ERROR(DimensionsFromShape(input_shape, data_format, + TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format, &batch_size_dim, &input_spatial_dims, &input_depth_dim, c)); - DimensionHandle output_depth_dim, filter_rows_dim, filter_cols_dim, - filter_input_depth_dim; - // If the input format is NCHW_VECT_C, the filter format is assumed to be - // OIHW_VECT_I, otherwise it is assumed to be HWIO. 
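An aside on the filter-format helpers used in the Conv2DShape rewrite here: for the plain layouts they reduce to looking up a dimension letter's position in the format name. A Python restatement of that mapping (a sketch under that assumption; the real helpers in tensorflow/core/util/tensor_format.h also handle OIHW_VECT_I, whose input channels are split across an outer and an inner dimension):

# Positions of each filter dimension in the two plain filter layouts.
FILTER_DIM_INDEX = {
    'HWIO': {'H': 0, 'W': 1, 'I': 2, 'O': 3},
    'OIHW': {'O': 0, 'I': 1, 'H': 2, 'W': 3},
}

def get_filter_dim_index(filter_format, dim):
    return FILTER_DIM_INDEX[filter_format][dim]

assert get_filter_dim_index('HWIO', 'O') == 3  # output depth is the last dim
assert get_filter_dim_index('OIHW', 'O') == 0  # ...or the first, under OIHW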
- if (data_format == FORMAT_NCHW_VECT_C) { - output_depth_dim = c->Dim(filter_shape, 0); - TF_RETURN_IF_ERROR(c->Multiply(c->Dim(filter_shape, 1), - c->Dim(filter_shape, 4), - &filter_input_depth_dim)); - filter_rows_dim = c->Dim(filter_shape, 2); - filter_cols_dim = c->Dim(filter_shape, 3); + DimensionHandle output_depth_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'O')); + DimensionHandle filter_rows_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'H')); + DimensionHandle filter_cols_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'W')); + DimensionHandle filter_input_depth_dim; + if (filter_format == FORMAT_OIHW_VECT_I) { + TF_RETURN_IF_ERROR(c->Multiply( + c->Dim(filter_shape, + GetFilterDimIndex(filter_format, 'I')), + c->Dim(filter_shape, + GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)), + &filter_input_depth_dim)); } else { - filter_rows_dim = c->Dim(filter_shape, 0); - filter_cols_dim = c->Dim(filter_shape, 1); - filter_input_depth_dim = c->Dim(filter_shape, 2); - output_depth_dim = c->Dim(filter_shape, 3); + filter_input_depth_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'I')); } // Check that the input tensor and the filter tensor agree on the input @@ -559,9 +589,6 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { } Status AvgPoolShape(shape_inference::InferenceContext* c) { - ShapeHandle input_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape)); - string data_format_str; TensorFormat data_format; Status s = c->GetAttr("data_format", &data_format_str); @@ -571,6 +598,10 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) { data_format = FORMAT_NHWC; } + const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); + TF_RETURN_IF_ERROR( CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); @@ -627,9 +658,6 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) { } Status MaxPoolShape(shape_inference::InferenceContext* c) { - ShapeHandle input_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape)); - string data_format_str; TensorFormat data_format; Status s = c->GetAttr("data_format", &data_format_str); @@ -639,6 +667,10 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) { data_format = FORMAT_NHWC; } + const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); + TF_RETURN_IF_ERROR( CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); @@ -696,11 +728,21 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) { } Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { - ShapeHandle input_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); + string data_format_str; + TensorFormat data_format; + Status s = c->GetAttr("data_format", &data_format_str); + if (s.ok()) { + FormatFromString(data_format_str, &data_format); + } else { + data_format = FORMAT_NHWC; + } - string data_format; - Status s = c->GetAttr("data_format", &data_format); + const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 
5 : 4; + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); + + TF_RETURN_IF_ERROR( + CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); std::vector kernel_sizes; std::vector strides; @@ -725,7 +767,8 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { } kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements()); auto kernel_sizes_vec = kernel_sizes_tensor->flat(); - std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), kernel_sizes.begin()); + std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), + kernel_sizes.begin()); const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1); if (strides_tensor == nullptr) { @@ -749,35 +792,22 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { kernel_sizes.size()); } - int32 stride_rows, stride_cols, stride_depth; - int32 kernel_rows, kernel_cols, kernel_depth; + int32 stride_depth = GetTensorDim(strides, data_format, 'C'); + int32 stride_rows = GetTensorDim(strides, data_format, 'H'); + int32 stride_cols = GetTensorDim(strides, data_format, 'W'); + int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); + int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); + int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); - if (s.ok() && data_format == "NCHW") { - // Canonicalize input shape to NHWC so the shape inference code below can - // process it. - auto dim = [&](char dimension) { - return c->Dim(input_shape, GetTensorDimIndex<2>(FORMAT_NCHW, dimension)); - }; - input_shape = c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('C')}}); - stride_depth = strides[1]; - stride_rows = strides[2]; - stride_cols = strides[3]; - kernel_depth = kernel_sizes[1]; - kernel_rows = kernel_sizes[2]; - kernel_cols = kernel_sizes[3]; - } else { - stride_rows = strides[1]; - stride_cols = strides[2]; - stride_depth = strides[3]; - kernel_rows = kernel_sizes[1]; - kernel_cols = kernel_sizes[2]; - kernel_depth = kernel_sizes[3]; - } - - DimensionHandle batch_size_dim = c->Dim(input_shape, 0); - DimensionHandle in_rows_dim = c->Dim(input_shape, 1); - DimensionHandle in_cols_dim = c->Dim(input_shape, 2); - DimensionHandle in_depth_dim = c->Dim(input_shape, 3); + constexpr int num_spatial_dims = 2; + DimensionHandle batch_size_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'N')); + DimensionHandle in_rows_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'H')); + DimensionHandle in_cols_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'W')); + DimensionHandle in_depth_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'C')); Padding padding; TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); @@ -791,15 +821,9 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); - output_shape = - c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); - if (data_format == "NCHW") { - // Convert output shape back to expected NCHW data format. 
- auto dim = [&](char dimension) { - return c->Dim(output_shape, GetTensorDimIndex<2>(FORMAT_NHWC, dimension)); - }; - output_shape = c->MakeShape({{dim('N'), dim('C'), dim('0'), dim('1')}}); - } + TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, + {output_rows, output_cols}, + output_depth, &output_shape, c)); c->set_output(0, output_shape); return Status::OK(); diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 14f6c1bb45f..ec9746b2af1 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/shape_inference_testutil.h" @@ -411,34 +412,35 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { TEST(CommonShapeFnsTest, Conv2DShapeTest) { ShapeInferenceTestOp op("Conv2D"); auto set_op = [&op](const std::vector& strides, const string& padding, - const string& data_format) { + const string& data_format, const string& filter_format) { TF_CHECK_OK(NodeDefBuilder("test", "Conv2D") .Input("input", 0, DT_FLOAT) .Input("filter", 0, DT_FLOAT) .Attr("strides", strides) .Attr("padding", padding) .Attr("data_format", data_format) + .Attr("filter_format", filter_format) .Finalize(&op.node_def)); }; // 1x1 filter - set_op({{1, 1, 1, 1}}, "VALID", "NHWC"); + set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO"); INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]"); // 2x2 filter - set_op({{1, 1, 1, 1}}, "VALID", "NHWC"); + set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO"); INFER_OK(op, "[1,2,2,1];[2,2,1,1]", "[d0_0,1,1,d1_3]"); // 3x3 input, 1x1 filter, 2x2 stride - set_op({{1, 2, 2, 1}}, "VALID", "NHWC"); + set_op({{1, 2, 2, 1}}, "VALID", "NHWC", "HWIO"); INFER_OK(op, "[1,3,3,1];[1,1,1,1]", "[d0_0,2,2,d1_3]"); // 3x3 input, 1x1 filter, 2x1 stride - set_op({{1, 2, 1, 1}}, "VALID", "NHWC"); + set_op({{1, 2, 1, 1}}, "VALID", "NHWC", "HWIO"); INFER_OK(op, "[1,3,3,1];[1,1,1,1]", "[d0_0,2,3,d1_3]"); // 4x4 input, 2x1 filter, 1x2 stride - set_op({{1, 1, 2, 1}}, "VALID", "NHWC"); + set_op({{1, 1, 2, 1}}, "VALID", "NHWC", "HWIO"); INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]"); // Invalid rank for input @@ -460,77 +462,76 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) { // Tests for NCHW // 1x1 filter - set_op({{1, 1, 1, 1}}, "VALID", "NCHW"); + set_op({{1, 1, 1, 1}}, "VALID", "NCHW", "HWIO"); INFER_OK(op, "[1,1,2,2];[1,1,1,1]", "[d0_0,d1_3,2,2]"); // 2x2 filter - set_op({{1, 1, 1, 1}}, "VALID", "NCHW"); + set_op({{1, 1, 1, 1}}, "VALID", "NCHW", "HWIO"); INFER_OK(op, "[1,1,2,2];[2,2,1,1]", "[d0_0,d1_3,1,1]"); // 3x3 input, 1x1 filter, 2x2 stride - set_op({{1, 1, 2, 2}}, "VALID", "NCHW"); + set_op({{1, 1, 2, 2}}, "VALID", "NCHW", "HWIO"); INFER_OK(op, "[1,1,3,3];[1,1,1,1]", "[d0_0,d1_3,2,2]"); // 3x3 input, 1x1 filter, 2x1 stride - set_op({{1, 1, 2, 1}}, "VALID", "NCHW"); + set_op({{1, 1, 2, 1}}, "VALID", "NCHW", "HWIO"); INFER_OK(op, "[1,1,3,3];[1,1,1,1]", "[d0_0,d1_3,2,3]"); // 4x4 input, 2x1 filter, 1x2 stride - set_op({{1, 1, 1, 2}}, "VALID", "NCHW"); + set_op({{1, 1, 1, 2}}, "VALID", "NCHW", "HWIO"); INFER_OK(op, "[1,1,4,4];[2,1,1,1]", "[d0_0,d1_3,3,2]"); // Tests for NCHW_VECT_C // 
1x1 filter - set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C"); + set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); INFER_OK(op, "[1,1,2,2,4];[4,1,1,1,4]", "[d0_0,1,2,2,4]"); // 2x2 filter - set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C"); + set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); INFER_OK(op, "[1,1,2,2,4];[4,1,2,2,4]", "[d0_0,1,1,1,4]"); // 3x3 input, 1x1 filter, 2x2 stride - set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C"); + set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); INFER_OK(op, "[1,1,3,3,4];[8,1,1,1,4]", "[d0_0,2,2,2,4]"); // 3x3 input, 1x1 filter, 2x1 stride - set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C"); + set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); INFER_OK(op, "[1,1,3,3,4];[4,1,1,1,4]", "[d0_0,1,2,3,4]"); // 4x4 input, 2x1 filter, 1x2 stride - set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C"); + set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); INFER_OK(op, "[1,1,4,4,4];[4,1,2,1,4]", "[d0_0,1,3,2,4]"); // Some tests for "SAME" padding // 4x4 input, 1x1 filter, 1x1 stride - set_op({{1, 1, 1, 1}}, "SAME", "NHWC"); + set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO"); INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); // 3x3 input, 2x2 filter, 1x1 stride - set_op({{1, 1, 1, 1}}, "SAME", "NHWC"); + set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO"); INFER_OK(op, "[1,3,3,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); // 4x4 input, 2x2 filter, 2x2 stride - set_op({{1, 2, 2, 1}}, "SAME", "NHWC"); + set_op({{1, 2, 2, 1}}, "SAME", "NHWC", "HWIO"); INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,2,2,d1_3]"); // 4x4 input, 2x2 filter, 1x1 stride - set_op({{1, 1, 1, 1}}, "SAME", "NHWC"); + set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO"); INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); // With stride 1x1 and SAME, unknown dims don't matter - filter dims except // for output channels are ignored for output, so all inputs are carried // through to output. - set_op({{1, 1, 1, 1}}, "SAME", "NHWC"); + set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO"); INFER_OK(op, "[1,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); INFER_OK(op, "[1,?,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); INFER_OK(op, "[1,4,?,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); INFER_OK(op, "[1,4,4,?];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); - INFER_OK(op, "[1,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); - INFER_OK(op, "[1,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); + INFER_OK(op, "[?,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); // With stride != 1, the input HW dims are divided to produce output dims. 
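(The expected shapes in these SAME-padding tests follow the standard rule: with SAME padding the filter size drops out and output_dim = ceil(input_dim / stride), so a 4x4 input at stride 2 yields 2x2, while an unknown input dimension stays unknown ("?") in the inferred shape.)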
- set_op({{1, 2, 2, 1}}, "SAME", "NHWC"); + set_op({{1, 2, 2, 1}}, "SAME", "NHWC", "HWIO"); INFER_OK(op, "[?,4,4,1];[?,?,?,?]", "[d0_0,2,2,d1_3]"); INFER_OK(op, "[1,?,4,1];[?,?,?,?]", "[d0_0,?,2,d1_3]"); INFER_OK(op, "[1,4,?,1];[?,?,?,?]", "[d0_0,2,?,d1_3]"); @@ -704,7 +705,7 @@ TEST(CommonShapeFnsTest, AvgPool2DShapeTest) { INFER_ERROR("Dimension must be 4 but is 3", op, "[2,5,7,11,3]"); // Invalid rank for input - INFER_ERROR("must be at least rank 4", op, "[4,4]"); + INFER_ERROR("Shape must be rank", op, "[4,4]"); } TEST(CommonShapeFnsTest, MaxPool2DShapeTest) { @@ -741,6 +742,48 @@ TEST(CommonShapeFnsTest, MaxPool2DShapeTest) { INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8]"); } +TEST(CommonShapeFnsTest, MaxPoolV22DShapeTest) { + ShapeInferenceTestOp op("MaxPoolV2"); + Tensor ksizes_tensor, strides_tensor; + auto set_op = [&op, &ksizes_tensor, &strides_tensor]( + const std::vector& strides, + const std::vector& ksizes, const string& padding, + const string& data_format) { + TF_CHECK_OK(NodeDefBuilder("test", "MaxPoolV2") + .Input("input", 0, DT_FLOAT) + .Input("ksize", 1, DT_INT32) + .Input("strides", 2, DT_INT32) + .Attr("padding", padding) + .Attr("data_format", data_format) + .Finalize(&op.node_def)); + ksizes_tensor = test::AsTensor(ksizes); + op.input_tensors.resize(3); + op.input_tensors[0] = nullptr; + op.input_tensors[1] = &ksizes_tensor; + strides_tensor = test::AsTensor(strides); + op.input_tensors[2] = &strides_tensor; + }; + + // Most of the functionality is tested by conv-like shapes, + // so we check the very-specific maxpooling features here, + // namely depthwise kernel and striding. + + // all 1 strides, depth 2 filter + set_op({1, 1, 1, 1}, {1, 1, 1, 2}, "VALID", "NHWC"); + INFER_OK(op, "[1,2,2,2];[4];[4]", "[d0_0,2,2,1]"); + + // depth 3 stride, 1x1x1 filter, NCHW + set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW"); + INFER_OK(op, "[1,7,5,5];[4];[4]", "[d0_0,3,5,5]"); + + // 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C tests + set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "SAME", "NCHW_VECT_C"); + INFER_OK(op, "[2,3,5,7,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_OK(op, "[5,7,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_OK(op, "[?,?,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8];[4];[4]"); +} + TEST(CommonShapeFnsTest, Pool3DShapeTest) { ShapeInferenceTestOp op("MaxPool3D"); auto set_op = [&op](const std::vector& strides, diff --git a/tensorflow/core/framework/summary.proto b/tensorflow/core/framework/summary.proto index ba490333310..55879f87831 100644 --- a/tensorflow/core/framework/summary.proto +++ b/tensorflow/core/framework/summary.proto @@ -42,7 +42,7 @@ message SummaryMetadata { // The content to store for the plugin. The best practice is for this to be // a binary serialized protocol buffer. - string content = 2; + bytes content = 2; } // Data that associates a summary with a certain plugin. 
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 4c793231974..cf5d6e8baac 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -1205,12 +1205,12 @@ int MklLayoutRewritePass::SetUpContiguousInputs( if (do_connect_conv2d_backprop_input_filter && iidx == kConv2DBackpropInputFilterInputSlotIdx) { GetNodeProducingMklTensor(g, old_node, conv2d_node, - kConv2DFilterOutputSlotIdx, - &mkl_node, &mkl_node_output_slot); + kConv2DFilterOutputSlotIdx, &mkl_node, + &mkl_node_output_slot); } else { GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, - old_node_inputs[iidx].second, - &mkl_node, &mkl_node_output_slot); + old_node_inputs[iidx].second, &mkl_node, + &mkl_node_output_slot); } nb->Input(mkl_node, mkl_node_output_slot); iidx++; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index efc5d7c553a..32a1b2c84d4 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -20,6 +20,7 @@ package_group( packages = [ "//learning/brain/contrib/...", "//learning/brain/research/sparse_matrix/...", + "//learning/faster_training/...", "//tensorflow/...", ], ) @@ -3350,10 +3351,9 @@ tf_cc_test( srcs = ["parse_tensor_test.cc"], deps = [ ":ops_testutil", - ":ops_util", ":parse_tensor_op", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", ], @@ -5587,6 +5587,20 @@ cc_library( ], ) +cc_library( + name = "dataset_utils", + srcs = ["dataset_utils.cc"], + hdrs = ["dataset_utils.h"], + deps = [ + ":captured_function", + ":dataset", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/util/tensor_bundle", + ], +) + cc_library( name = "captured_function", srcs = ["captured_function.cc"], @@ -5713,6 +5727,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -5727,6 +5742,22 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_kernel_library( + name = "sloppy_interleave_dataset_op", + srcs = ["sloppy_interleave_dataset_op.cc"], + deps = [ + ":captured_function", + ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -5963,6 +5994,7 @@ tf_kernel_library( ":repeat_dataset_op", ":shuffle_dataset_op", ":skip_dataset_op", + ":sloppy_interleave_dataset_op", ":sparse_tensor_slice_dataset_op", ":sql_dataset_ops", ":take_dataset_op", diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index 168cf37bc77..c852dc9991c 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -92,11 +92,11 @@ class ConvParameters { ConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, int64 out_depths, const SpatialArray& filter, const SpatialArray& stride, const SpatialArray& padding, - const DataType& dtype, int device_id) + DataType dtype, int device_id) : batch_(batch), in_depths_(in_depths), - in_(in), out_depths_(out_depths), + in_(in), filter_(filter), 
stride_(stride), padding_(padding), @@ -130,7 +130,8 @@ class ConvParameters { "(", str_util::Join(filter_, ", "), "), ", "(", str_util::Join(stride_, ", "), "), ", "(", str_util::Join(padding_, ", "), "), ", - dtype_, ", ", device_id_); + dtype_, ", ", + device_id_); // clang-format on } @@ -150,26 +151,28 @@ class ConvParameters { } } - private: - typedef std::tuple - ParameterDataType; + protected: + using ParameterDataType = + std::tuple; ParameterDataType get_data_as_tuple() const { return std::make_tuple(batch_, in_depths_, in_, out_depths_, filter_, stride_, padding_, dtype_, device_id_); } + uint64 hash_code_; + + private: int64 batch_; int64 in_depths_; - SpatialArray in_; int64 out_depths_; + SpatialArray in_; SpatialArray filter_; SpatialArray stride_; SpatialArray padding_; DataType dtype_; int device_id_; - uint64 hash_code_; }; typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index 2307c2de0e6..3d4670c9bae 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -556,6 +556,7 @@ template struct functor::NCHWToNHWC; template struct functor::NCHWToNHWC; template struct functor::NCHWToNHWC; +template struct functor::PadInput; template struct functor::PadInput; template struct functor::PadInput; diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc index 56181a686ca..45cc2fbbb8b 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/crop_and_resize_op.cc @@ -19,59 +19,98 @@ limitations under the License. #include "tensorflow/core/kernels/crop_and_resize_op.h" +#include +#include + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/work_sharder.h" #if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/platform/cuda.h" #include "tensorflow/core/platform/stream_executor.h" + +using ::perftools::gputools::cuda::ScopedActivateExecutorContext; #endif // GOOGLE_CUDA namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +using Callback = std::function; -static inline void ParseAndCheckBoxSizes(OpKernelContext* context, - const Tensor& boxes, - const Tensor& box_ind, - int* num_boxes) { - if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) { +namespace { + +static inline Status ParseAndCheckBoxSizes(const Tensor& boxes, + const Tensor& box_index, + int* num_boxes) { + if (boxes.NumElements() == 0 && box_index.NumElements() == 0) { *num_boxes = 0; - return; + return Status::OK(); } // The shape of 'boxes' is [num_boxes, 4]. 
- OP_REQUIRES(context, boxes.dims() == 2, - errors::InvalidArgument("boxes must be 2-D", - boxes.shape().DebugString())); + if (boxes.dims() != 2) { + return errors::InvalidArgument("boxes must be 2-D", + boxes.shape().DebugString()); + } *num_boxes = boxes.dim_size(0); - OP_REQUIRES(context, boxes.dim_size(1) == 4, - errors::InvalidArgument("boxes must have 4 columns")); - - // The shape of 'box_ind' is [num_boxes]. - OP_REQUIRES(context, box_ind.dims() == 1, - errors::InvalidArgument("box_ind must be 1-D", - box_ind.shape().DebugString())); - OP_REQUIRES(context, box_ind.dim_size(0) == *num_boxes, - errors::InvalidArgument("box_ind has incompatible shape")); + if (boxes.dim_size(1) != 4) { + return errors::InvalidArgument("boxes must have 4 columns"); + } + // The shape of 'box_index' is [num_boxes]. + if (box_index.dims() != 1) { + return errors::InvalidArgument("box_index must be 1-D", + box_index.shape().DebugString()); + } + if (box_index.dim_size(0) != *num_boxes) { + return errors::InvalidArgument("box_index has incompatible shape"); + } + return Status::OK(); } -// Verifies that all values in box_ind are in [0, batch). +// Conditionally calls the compute callback if all values in box_index are in +// [0, batch_size) then calls done. template -inline void CheckValidBoxInd( - OpKernelContext* context, - typename TTypes::ConstTensor box_ind_data, int batch); +inline void RunIfBoxIndexIsValid( + OpKernelContext* context, typename TTypes::ConstTensor box_index, + int batch_size, const Callback& compute, const Callback& done); + +// Specialization of CheckValidBoxIndex for a CPUDevice. +template <> +inline void RunIfBoxIndexIsValid( + OpKernelContext* context, typename TTypes::ConstTensor box_index, + int batch_size, const Callback& compute, const Callback& done) { + const int num_boxes = box_index.dimension(0); + for (int b = 0; b < num_boxes; ++b) { + OP_REQUIRES_ASYNC( + context, FastBoundsCheck(box_index(b), batch_size), + errors::OutOfRange("box_index has values outside [0, batch_size)"), + done); + } + if (compute) { + compute(); + } + if (done) { + done(); + } +} + +} // namespace template -class CropAndResizeOp : public OpKernel { +class CropAndResizeOp : public AsyncOpKernel { public: - explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) { + explicit CropAndResizeOp(OpKernelConstruction* context) + : AsyncOpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", @@ -80,69 +119,77 @@ class CropAndResizeOp : public OpKernel { &extrapolation_value_)); } - void Compute(OpKernelContext* context) override { - // The shape of 'image' is [batch, image_height, image_width, channels]. + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + // The shape of 'image' is [batch_size, image_height, image_width, + // channels]. const Tensor& image = context->input(0); - OP_REQUIRES(context, image.dims() == 4, - errors::InvalidArgument("input image must be 4-D", - image.shape().DebugString())); - - const int batch = image.dim_size(0); - const int image_height = image.dim_size(1); - const int image_width = image.dim_size(2); - const int depth = image.dim_size(3); - OP_REQUIRES(context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive")); - // The shape of 'boxes' is [num_boxes, 4]. const Tensor& boxes = context->input(1); - - // The shape of 'box_ind' is [num_boxes]. 
- const Tensor& box_ind = context->input(2); - - int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); - + // The shape of 'box_index' is [num_boxes]. + const Tensor& box_index = context->input(2); // The shape of 'crop_size' is [2]. const Tensor& crop_size = context->input(3); - OP_REQUIRES(context, crop_size.dims() == 1, - errors::InvalidArgument("crop_size must be 1-D", - crop_size.shape().DebugString())); - OP_REQUIRES(context, crop_size.dim_size(0) == 2, - errors::InvalidArgument("crop_size must have two elements", - crop_size.shape().DebugString())); + // Validate inputs dimensions. + OP_REQUIRES_ASYNC(context, image.dims() == 4, + errors::InvalidArgument("input image must be 4-D", + image.shape().DebugString()), + done); + const int batch_size = image.dim_size(0); + const int image_height = image.dim_size(1); + const int image_width = image.dim_size(2); + const int depth = image.dim_size(3); + OP_REQUIRES_ASYNC( + context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive"), done); + int num_boxes = 0; + OP_REQUIRES_OK_ASYNC( + context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); + OP_REQUIRES_ASYNC(context, crop_size.dims() == 1, + errors::InvalidArgument("crop_size must be 1-D", + crop_size.shape().DebugString()), + done); + OP_REQUIRES_ASYNC( + context, crop_size.dim_size(0) == 2, + errors::InvalidArgument("crop_size must have two elements", + crop_size.shape().DebugString()), + done); + + // Copy and validate crop sizes. auto crop_size_vec = crop_size.vec(); const int crop_height = internal::SubtleMustCopy(crop_size_vec(0)); const int crop_width = internal::SubtleMustCopy(crop_size_vec(1)); - OP_REQUIRES(context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("crop dimensions must be positive")); + OP_REQUIRES_ASYNC( + context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("crop dimensions must be positive"), done); // Allocate output tensor. 
Tensor* output = nullptr; - OP_REQUIRES_OK( + OP_REQUIRES_OK_ASYNC( context, context->allocate_output( 0, TensorShape({num_boxes, crop_height, crop_width, depth}), - &output)); + &output), + done); - typename TTypes::ConstTensor image_data = image.tensor(); - typename TTypes::ConstTensor boxes_data = - boxes.tensor(); - typename TTypes::ConstTensor box_ind_data = - box_ind.tensor(); - typename TTypes::Tensor crops_data = output->tensor(); + auto compute_callback = [this, context, output]() { + const Tensor& image = context->input(0); + const Tensor& boxes = context->input(1); + const Tensor& box_index = context->input(2); + const bool status = functor::CropAndResize()( + context, image.tensor(), boxes.tensor(), + box_index.tensor(), extrapolation_value_, + output->tensor()); + if (!status) { + context->SetStatus( + errors::Internal("Failed launch CropAndResizeKernel.")); + } + }; - CheckValidBoxInd(context, box_ind_data, batch); - - bool status = functor::CropAndResize()( - context, image_data, boxes_data, box_ind_data, extrapolation_value_, - crops_data); - if (!status) { - context->SetStatus( - errors::Internal("Failed launch CropAndResizeKernel.")); - } + RunIfBoxIndexIsValid(context, box_index.tensor(), + batch_size, std::move(compute_callback), + std::move(done)); } private: @@ -156,10 +203,10 @@ struct CropAndResize { bool operator()(const OpKernelContext* context, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, - typename TTypes::ConstTensor box_ind, + typename TTypes::ConstTensor box_index, float extrapolation_value, typename TTypes::Tensor crops) { - const int batch = image.dimension(0); + const int batch_size = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -176,8 +223,8 @@ struct CropAndResize { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_ind(b); - if (b_in < 0 || b_in >= batch) { + const int32 b_in = box_index(b); + if (!FastBoundsCheck(b_in, batch_size)) { continue; } @@ -255,89 +302,94 @@ struct CropAndResize { return true; } }; + } // namespace functor template -class CropAndResizeGradImageOp : public OpKernel { +class CropAndResizeGradImageOp : public AsyncOpKernel { public: explicit CropAndResizeGradImageOp(OpKernelConstruction* context) - : OpKernel(context) { + : AsyncOpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", errors::InvalidArgument("method must be 'bilinear'", method)); } - void Compute(OpKernelContext* context) override { + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. const Tensor& grads = context->input(0); - - OP_REQUIRES(context, grads.dims() == 4, - errors::InvalidArgument("grads image must be 4-D", - grads.shape().DebugString())); - const int crop_height = grads.dim_size(1); - const int crop_width = grads.dim_size(2); - OP_REQUIRES(context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("grads dimensions must be positive")); - // The shape of 'boxes' is [num_boxes, 4]. const Tensor& boxes = context->input(1); - - // The shape of 'box_ind' is [num_boxes]. 
- const Tensor& box_ind = context->input(2); - - int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); - - OP_REQUIRES( - context, grads.dim_size(0) == num_boxes, - errors::InvalidArgument("boxes and grads have incompatible shape")); - + // The shape of 'box_index' is [num_boxes]. + const Tensor& box_index = context->input(2); // The shape of 'image_size' is [4]. const Tensor& image_size = context->input(3); - OP_REQUIRES(context, image_size.dims() == 1, - errors::InvalidArgument("image_size must be 1-D", - image_size.shape().DebugString())); - OP_REQUIRES(context, image_size.dim_size(0) == 4, - errors::InvalidArgument("image_size must have 4 elements", - image_size.shape().DebugString())); + // Validate input shapes. + OP_REQUIRES_ASYNC(context, grads.dims() == 4, + errors::InvalidArgument("grads image must be 4-D", + grads.shape().DebugString()), + done); + const int crop_height = grads.dim_size(1); + const int crop_width = grads.dim_size(2); + OP_REQUIRES_ASYNC( + context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("grads dimensions must be positive"), done); + int num_boxes = 0; + OP_REQUIRES_OK_ASYNC( + context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); + OP_REQUIRES_ASYNC( + context, grads.dim_size(0) == num_boxes, + errors::InvalidArgument("boxes and grads have incompatible shape"), + done); + + OP_REQUIRES_ASYNC(context, image_size.dims() == 1, + errors::InvalidArgument("image_size must be 1-D", + image_size.shape().DebugString()), + done); + OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4, + errors::InvalidArgument("image_size must have 4 elements", + image_size.shape().DebugString()), + done); auto image_size_vec = image_size.vec(); - const int batch = internal::SubtleMustCopy(image_size_vec(0)); + const int batch_size = internal::SubtleMustCopy(image_size_vec(0)); const int image_height = internal::SubtleMustCopy(image_size_vec(1)); const int image_width = internal::SubtleMustCopy(image_size_vec(2)); const int depth = internal::SubtleMustCopy(image_size_vec(3)); - - OP_REQUIRES(context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive")); - OP_REQUIRES( + OP_REQUIRES_ASYNC( + context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive"), done); + OP_REQUIRES_ASYNC( context, grads.dim_size(3) == depth, - errors::InvalidArgument("image_size and grads are incompatible")); + errors::InvalidArgument("image_size and grads are incompatible"), done); // Allocate output tensor. 
Tensor* output = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output( - 0, TensorShape({batch, image_height, image_width, depth}), - &output)); + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_output( + 0, TensorShape({batch_size, image_height, image_width, depth}), + &output), + done); - typename TTypes::ConstTensor grads_data = - grads.tensor(); - typename TTypes::ConstTensor boxes_data = - boxes.tensor(); - typename TTypes::ConstTensor box_ind_data = - box_ind.tensor(); - typename TTypes::Tensor output_data = output->tensor(); + auto compute_callback = [context, output]() { + const Tensor& grads = context->input(0); + const Tensor& boxes = context->input(1); + const Tensor& box_index = context->input(2); + const bool status = functor::CropAndResizeBackpropImage()( + context->eigen_device(), grads.tensor(), + boxes.tensor(), box_index.tensor(), + output->tensor()); + if (!status) { + context->SetStatus(errors::Internal( + "Failed launch CropAndResizeBackpropImage kernel.")); + } + }; - CheckValidBoxInd(context, box_ind_data, batch); - - bool status = functor::CropAndResizeBackpropImage()( - context->eigen_device(), grads_data, boxes_data, box_ind_data, - output_data); - if (!status) { - context->SetStatus( - errors::Internal("Failed launch CropAndResizeBackpropImageKernel.")); - } + RunIfBoxIndexIsValid(context, box_index.tensor(), + batch_size, std::move(compute_callback), + std::move(done)); } }; @@ -348,9 +400,9 @@ struct CropAndResizeBackpropImage { bool operator()(const CPUDevice& d, typename TTypes::ConstTensor grads, typename TTypes::ConstTensor boxes, - typename TTypes::ConstTensor box_ind, + typename TTypes::ConstTensor box_index, typename TTypes::Tensor grads_image) { - const int batch = grads_image.dimension(0); + const int batch_size = grads_image.dimension(0); const int image_height = grads_image.dimension(1); const int image_width = grads_image.dimension(2); @@ -367,8 +419,8 @@ struct CropAndResizeBackpropImage { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_ind(b); - if (b_in < 0 || b_in >= batch) { + const int32 b_in = box_index(b); + if (!FastBoundsCheck(b_in, batch_size)) { continue; } @@ -419,83 +471,90 @@ struct CropAndResizeBackpropImage { return true; } }; + } // namespace functor template -class CropAndResizeGradBoxesOp : public OpKernel { +class CropAndResizeGradBoxesOp : public AsyncOpKernel { public: explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context) - : OpKernel(context) { + : AsyncOpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", errors::InvalidArgument("method must be 'bilinear'", method)); } - void Compute(OpKernelContext* context) override { + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. const Tensor& grads = context->input(0); + // The shape of 'boxes' is [num_boxes, 4]. + const Tensor& boxes = context->input(2); + // The shape of 'box_index' is [num_boxes]. + const Tensor& box_index = context->input(3); + // The shape of 'image' is [batch_size, image_height, image_width, depth]. + const Tensor& image = context->input(1); - OP_REQUIRES(context, grads.dims() == 4, - errors::InvalidArgument("grads image must be 4-D", - grads.shape().DebugString())); - + // Validate input shapes. 
+ OP_REQUIRES_ASYNC(context, grads.dims() == 4, + errors::InvalidArgument("grads image must be 4-D", + grads.shape().DebugString()), + done); const int crop_height = grads.dim_size(1); const int crop_width = grads.dim_size(2); const int depth = grads.dim_size(3); - OP_REQUIRES(context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("grads dimensions must be positive")); + OP_REQUIRES_ASYNC( + context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("grads dimensions must be positive"), done); - // The shape of 'image' is [batch, image_height, image_width, depth]. - const Tensor& image = context->input(1); - OP_REQUIRES(context, image.dims() == 4, - errors::InvalidArgument("input image must be 4-D", - image.shape().DebugString())); - - const int batch = image.dim_size(0); + OP_REQUIRES_ASYNC(context, image.dims() == 4, + errors::InvalidArgument("input image must be 4-D", + image.shape().DebugString()), + done); + const int batch_size = image.dim_size(0); const int image_height = image.dim_size(1); const int image_width = image.dim_size(2); - OP_REQUIRES(context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive")); - OP_REQUIRES(context, image.dim_size(3) == depth, - errors::InvalidArgument("image, grads depth differ")); - - // The shape of 'boxes' is [num_boxes, 4]. - const Tensor& boxes = context->input(2); - - // The shape of 'box_ind' is [num_boxes]. - const Tensor& box_ind = context->input(3); + OP_REQUIRES_ASYNC( + context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive"), done); + OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth, + errors::InvalidArgument("image, grads depth differ"), + done); int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); + OP_REQUIRES_OK_ASYNC( + context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); - OP_REQUIRES( + OP_REQUIRES_ASYNC( context, grads.dim_size(0) == num_boxes, - errors::InvalidArgument("boxes and grads have incompatible shape")); + errors::InvalidArgument("boxes and grads have incompatible shape"), + done); // Allocate output tensor. 
Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - 0, TensorShape({num_boxes, 4}), &output)); + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_output(0, TensorShape({num_boxes, 4}), &output), + done); - typename TTypes::ConstTensor grads_data = - grads.tensor(); - typename TTypes::ConstTensor image_data = image.tensor(); - typename TTypes::ConstTensor boxes_data = - boxes.tensor(); - typename TTypes::ConstTensor box_ind_data = - box_ind.tensor(); - typename TTypes::Tensor output_data = output->tensor(); + auto compute_callback = [context, output]() { + const Tensor& grads = context->input(0); + const Tensor& image = context->input(1); + const Tensor& boxes = context->input(2); + const Tensor& box_index = context->input(3); + const bool status = functor::CropAndResizeBackpropBoxes()( + context->eigen_device(), grads.tensor(), + image.tensor(), boxes.tensor(), + box_index.tensor(), output->tensor()); + if (!status) { + context->SetStatus(errors::Internal( + "Failed launch CropAndResizeBackpropBoxes kernel.")); + } + }; - CheckValidBoxInd(context, box_ind_data, batch); - - bool status = functor::CropAndResizeBackpropBoxes()( - context->eigen_device(), grads_data, image_data, boxes_data, - box_ind_data, output_data); - if (!status) { - context->SetStatus( - errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel.")); - } + RunIfBoxIndexIsValid(context, box_index.tensor(), + batch_size, std::move(compute_callback), + std::move(done)); } }; @@ -507,9 +566,9 @@ struct CropAndResizeBackpropBoxes { typename TTypes::ConstTensor grads, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, - typename TTypes::ConstTensor box_ind, + typename TTypes::ConstTensor box_index, typename TTypes::Tensor grads_boxes) { - const int batch = image.dimension(0); + const int batch_size = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -526,8 +585,8 @@ struct CropAndResizeBackpropBoxes { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_ind(b); - if (b_in < 0 || b_in >= batch) { + const int32 b_in = box_index(b); + if (!FastBoundsCheck(b_in, batch_size)) { continue; } @@ -609,30 +668,19 @@ struct CropAndResizeBackpropBoxes { return true; } }; + } // namespace functor -// Specialization of CheckValidBoxInd for a CPUDevice. -template <> -inline void CheckValidBoxInd( - OpKernelContext* context, typename TTypes::ConstTensor box_ind, - int batch) { - const int num_boxes = box_ind.dimension(0); - for (int b = 0; b < num_boxes; ++b) { - OP_REQUIRES(context, box_ind(b) >= 0 && box_ind(b) < batch, - errors::OutOfRange("box_ind has values outside [0, batch)")); - } -} - -#define REGISTER_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .HostMemory("crop_size"), \ - CropAndResizeOp); \ - \ - REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("crop_size"), \ + CropAndResizeOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ CropAndResizeGradBoxesOp); TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); @@ -654,50 +702,93 @@ TF_CALL_double(REGISTER_KERNEL); #if GOOGLE_CUDA -// Forward declaration of the CheckValidBoxIndHelper specialization for GPU. 
+// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU. namespace functor { template <> -void CheckValidBoxIndHelper::operator()( - const GPUDevice& d, typename TTypes::ConstTensor box_ind, - int batch, typename TTypes::Tensor isvalid); -extern template struct CheckValidBoxIndHelper; +void CheckValidBoxIndexHelper::operator()( + const GPUDevice& d, typename TTypes::ConstTensor box_index, + int batch_size, typename TTypes::Tensor isvalid); +extern template struct CheckValidBoxIndexHelper; } // namespace functor -// Specialization of CheckValidBoxInd for a GPUDevice. +namespace { + +// Specialization of CheckValidBoxIndex for a GPUDevice. template <> -inline void CheckValidBoxInd( - OpKernelContext* context, typename TTypes::ConstTensor box_ind, - int batch) { - const int num_boxes = box_ind.dimension(0); +inline void RunIfBoxIndexIsValid( + OpKernelContext* context, typename TTypes::ConstTensor box_index, + int batch_size, const Callback& compute, const Callback& done) { + const int num_boxes = box_index.dimension(0); if (num_boxes == 0) { + compute(); + done(); return; } - Tensor isvalid_tensor; - OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum::value, - TensorShape({}), &isvalid_tensor)); - typename TTypes::Tensor isvalid = isvalid_tensor.tensor(); + Tensor isvalid_dev_tensor; + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum::value, TensorShape({}), + &isvalid_dev_tensor), + done); + typename TTypes::Tensor isvalid_dev = + isvalid_dev_tensor.tensor(); - functor::CheckValidBoxIndHelper()( - context->eigen_device(), box_ind, batch, isvalid); + // Run the actual box check on the device. + functor::CheckValidBoxIndexHelper()( + context->eigen_device(), box_index, batch_size, isvalid_dev); + // Copy the result back to the host. auto* stream = context->op_device_context()->stream(); - OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + OP_REQUIRES_ASYNC(context, stream, + errors::Internal("No GPU stream available."), done); + Tensor isvalid_host_tensor; + // Use pinned host memory on the host to avoid unnecessary + // synchronization. + AllocatorAttributes alloc_attr; + alloc_attr.set_on_host(true); + alloc_attr.set_gpu_compatible(true); + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum::value, TensorShape({}), + &isvalid_host_tensor, alloc_attr), + done); + perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(), + sizeof(bool)); + const bool status = + stream + ->ThenMemcpy( + isvalid_host_tensor.scalar().data() /* destination */, + wrapped /* source */, sizeof(bool)) + .ok(); + OP_REQUIRES_ASYNC( + context, status, + errors::Internal("Failed to launch copy of isvalid from device to host."), + done); - bool isvalid_host = false; - perftools::gputools::DeviceMemoryBase isvalid_gpu(isvalid.data(), - sizeof(bool)); - stream->ThenMemcpy(&isvalid_host, isvalid_gpu, sizeof(bool)); - stream->BlockHostUntilDone(); + // We capture both temporary tensors to prevent them from being deallocated + // when ComputeAsync returns and before the closure runs. 
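(Expanding on the comment above before the code continues: the callback scheduled on the event manager runs only once the GPU stream has passed the memcpy, long after ComputeAsync has returned, so the device-side scalar must be kept alive explicitly. TensorReference takes an extra reference on the tensor's buffer, and the callback drops it after the copied value has been consumed. Schematically, an illustrative fragment rather than the exact code below:

    TensorReference ref(isvalid_dev_tensor);  // +1 ref: GPU buffer stays alive
    event_mgr->ThenExecute(stream, [ref, done]() {
      // By now the memcpy on `stream` has filled the pinned host scalar.
      ref.Unref();                            // -1 ref: buffer may be freed
      done();
    });

The pinned, gpu_compatible host allocation above is what lets the copy complete without forcing a full device synchronization.)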
+ TensorReference isvalid_dev_ref(isvalid_dev_tensor); + auto wrapped_callback = [context, isvalid_host_tensor, isvalid_dev_ref, + compute, done]() { + auto stream = context->op_device_context()->stream(); + ScopedActivateExecutorContext scoped_activation{stream->parent()}; + const bool isvalid = isvalid_host_tensor.scalar()(); + isvalid_dev_ref.Unref(); + OP_REQUIRES_ASYNC( + context, isvalid, + errors::OutOfRange("box_index has values outside [0, batch_size)"), + done); + compute(); + done(); + }; - OP_REQUIRES(context, stream->ok(), - errors::Internal("cudaMemcpy from device to host failed")); - - OP_REQUIRES(context, isvalid_host, - errors::OutOfRange("box_ind has values outside [0, batch)")); + context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( + stream, wrapped_callback); } +} // namespace + #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ .Device(DEVICE_GPU) \ diff --git a/tensorflow/core/kernels/crop_and_resize_op.h b/tensorflow/core/kernels/crop_and_resize_op.h index 84d7a5e03b8..b6b1dbd7b0c 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.h +++ b/tensorflow/core/kernels/crop_and_resize_op.h @@ -55,12 +55,12 @@ struct CropAndResizeBackpropBoxes { }; template -struct CheckValidBoxIndHelper { - // Checks if all values in box_ind are in [0, batch). +struct CheckValidBoxIndexHelper { + // Checks if all values in box_index are in [0, batch). void operator()(const Device& d, - typename TTypes::ConstTensor box_ind, int batch, + typename TTypes::ConstTensor box_index, int batch, typename TTypes::Tensor isvalid) { - isvalid.device(d) = ((box_ind >= 0) && (box_ind < batch)).all(); + isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all(); } }; diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc index 1726e4a8168..d12787d5244 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc @@ -442,7 +442,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); #undef DEFINE_GPU_SPECS -template struct CheckValidBoxIndHelper; +template struct CheckValidBoxIndexHelper; } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc index 1bf28d4d003..22c659b587b 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_test.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc @@ -251,7 +251,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) { Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE( - StringPiece(s.ToString()).contains("box_ind has incompatible shape")) + StringPiece(s.ToString()).contains("box_index has incompatible shape")) << s; } @@ -264,7 +264,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) { Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE(StringPiece(s.ToString()) - .contains("box_ind has values outside [0, batch)")) + .contains("box_index has values outside [0, batch_size)")) << s; } diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h index ac6119d8a21..0fd6450f982 100644 --- a/tensorflow/core/kernels/cuda_solvers.h +++ b/tensorflow/core/kernels/cuda_solvers.h @@ -313,6 +313,9 @@ class ScratchSpace { int64 size() const { return scratch_tensor_.NumElements(); } const string& debug_info() const { return debug_info_; } + Tensor& tensor() { return scratch_tensor_; } + const Tensor& tensor() const { return 
scratch_tensor_; } + // Returns true if this ScratchSpace is in host memory. bool on_host() const { return on_host_; } diff --git a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc index b9e42b4d00d..af6c094d7ac 100644 --- a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc +++ b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc @@ -51,55 +51,57 @@ namespace { // Hacks around missing support for complex arithmetic in nvcc. template -__host__ __device__ inline Scalar Multiply(Scalar x, Scalar y) { +__device__ inline Scalar Multiply(Scalar x, Scalar y) { return x * y; } template <> -__host__ __device__ inline cuComplex Multiply(cuComplex x, cuComplex y) { +__device__ inline cuComplex Multiply(cuComplex x, cuComplex y) { return cuCmulf(x, y); } template <> -__host__ __device__ inline cuDoubleComplex Multiply(cuDoubleComplex x, - cuDoubleComplex y) { +__device__ inline cuDoubleComplex Multiply(cuDoubleComplex x, + cuDoubleComplex y) { return cuCmul(x, y); } template -__host__ __device__ inline Scalar Negate(Scalar x) { +__device__ inline Scalar Negate(Scalar x) { return -x; } template <> -__host__ __device__ inline cuComplex Negate(cuComplex x) { +__device__ inline cuComplex Negate(cuComplex x) { return make_cuComplex(-cuCrealf(x), -cuCimagf(x)); } template <> -__host__ __device__ inline cuDoubleComplex Negate(cuDoubleComplex x) { +__device__ inline cuDoubleComplex Negate(cuDoubleComplex x) { return make_cuDoubleComplex(-cuCreal(x), -cuCimag(x)); } template -__host__ __device__ inline bool IsFinite(Scalar x) { - return isfinite(x); +__device__ inline bool IsFinite(Scalar x) { + return Eigen::numext::isfinite(x); } template <> -__host__ __device__ inline bool IsFinite(cuComplex x) { - return isfinite(cuCrealf(x)) && isfinite(cuCimagf(x)); +__device__ inline bool IsFinite(cuComplex x) { + return Eigen::numext::isfinite(cuCrealf(x)) && + Eigen::numext::isfinite(cuCimagf(x)); } template <> -__host__ __device__ inline bool IsFinite(cuDoubleComplex x) { - return isfinite(cuCreal(x)) && isfinite(cuCimag(x)); +__device__ inline bool IsFinite(cuDoubleComplex x) { + return Eigen::numext::isfinite(cuCreal(x)) && + Eigen::numext::isfinite(cuCimag(x)); } template struct Const { template - __host__ __device__ static inline Scalar make_const(const RealScalar x) { + __device__ static inline Scalar make_const(const RealScalar x) { return Scalar(x); } }; @@ -107,7 +109,7 @@ struct Const { template <> struct Const { template - __host__ __device__ static inline cuComplex make_const(const RealScalar x) { + __device__ static inline cuComplex make_const(const RealScalar x) { return make_cuComplex(x, 0.0f); } }; @@ -115,8 +117,7 @@ struct Const { template <> struct Const { template - __host__ __device__ static inline cuDoubleComplex make_const( - const RealScalar x) { + __device__ static inline cuDoubleComplex make_const(const RealScalar x) { return make_cuDoubleComplex(x, 0.0f); } }; diff --git a/tensorflow/core/kernels/dataset_utils.cc b/tensorflow/core/kernels/dataset_utils.cc new file mode 100644 index 00000000000..f320b3b09c6 --- /dev/null +++ b/tensorflow/core/kernels/dataset_utils.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/dataset_utils.h" + +namespace tensorflow { + +namespace dataset { + +Status MakeIteratorFromInputElement( + IteratorContext* ctx, const std::vector& input_element, + int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, + std::unique_ptr* out_iterator) { + FunctionLibraryRuntime::Options opts; + opts.runner = ctx->runner(); + // Choose a step ID that is guaranteed not to clash with any + // Session-generated step ID. DirectSession only generates + // non-negative step IDs (contiguous, starting from 0), and + // MasterSession generates 56-bit random step IDs whose MSB + // is always 0, so a negative random step ID should suffice. + opts.step_id = CapturedFunction::generate_step_id(); + ScopedStepContainer step_container( + opts.step_id, [captured_func, ctx](const string& name) { + captured_func->resource_manager()->Cleanup(name).IgnoreError(); + }); + opts.step_container = &step_container; + std::vector return_values; + TF_RETURN_IF_ERROR(captured_func->Run(opts, input_element, &return_values)); + + if (!(return_values.size() == 1 && return_values[0].dtype() == DT_RESOURCE && + TensorShapeUtils::IsScalar(return_values[0].shape()))) { + return errors::InvalidArgument( + "Function must return a single scalar of dtype DT_RESOURCE."); + } + + // Retrieve the dataset that was created in `f`. + DatasetBase* returned_dataset; + const ResourceHandle& dataset_resource = + return_values[0].scalar()(); + + // NOTE(mrry): We cannot use the core `LookupResource()` or + // `DeleteResource()` functions, because we have an + // `IteratorContext*` and not an `OpKernelContext*`, so we + // replicate the necessary functionality here. + auto type_index = MakeTypeIndex(); + if (type_index.hash_code() != dataset_resource.hash_code()) { + return errors::InvalidArgument("Function must return a Dataset resource."); + } + TF_RETURN_IF_ERROR(captured_func->resource_manager()->Lookup( + dataset_resource.container(), dataset_resource.name(), + &returned_dataset)); + core::ScopedUnref unref_dataset(returned_dataset); + + // Create an iterator for the dataset that was returned by + // `f`. This transfers ownership of the dataset to the + // iterator, so we can delete it from the resource manager. + *out_iterator = returned_dataset->MakeIterator( + strings::StrCat(prefix, "[", thread_index, "]")); + TF_RETURN_IF_ERROR(captured_func->resource_manager()->Delete( + dataset_resource.container(), dataset_resource.name())); + return Status::OK(); +} + +} // namespace dataset + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/dataset_utils.h b/tensorflow/core/kernels/dataset_utils.h new file mode 100644 index 00000000000..eea2b8802b8 --- /dev/null +++ b/tensorflow/core/kernels/dataset_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_UTILS_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/captured_function.h" +#include "tensorflow/core/kernels/dataset.h" + +namespace tensorflow { + +namespace dataset { + +Status MakeIteratorFromInputElement( + IteratorContext* ctx, const std::vector& input_element, + int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, + std::unique_ptr* out_iterator); + +} // namespace dataset + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_UTILS_H_ diff --git a/tensorflow/core/kernels/flat_map_dataset_op.cc b/tensorflow/core/kernels/flat_map_dataset_op.cc index e2310fecc76..a87e54bf310 100644 --- a/tensorflow/core/kernels/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/flat_map_dataset_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/kernels/captured_function.h" +#include "tensorflow/core/kernels/dataset_utils.h" namespace tensorflow { @@ -125,58 +126,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - FunctionLibraryRuntime::Options opts; - opts.runner = ctx->runner(); - opts.step_id = CapturedFunction::generate_step_id(); - ScopedStepContainer step_container( - opts.step_id, [this, ctx](const string& name) { - dataset() - ->captured_func_->resource_manager() - ->Cleanup(name) - .IgnoreError(); - }); - opts.step_container = &step_container; - std::vector return_values; - TF_RETURN_IF_ERROR( - dataset()->captured_func_->Run(opts, args, &return_values)); - - if (!(return_values.size() == 1 && - return_values[0].dtype() == DT_RESOURCE && - TensorShapeUtils::IsScalar(return_values[0].shape()))) { - return errors::InvalidArgument( - "`f` must return a single scalar of dtype DT_RESOURCE."); - } - - // Retrieve the dataset that was created in `f`. - DatasetBase* returned_dataset; - const ResourceHandle& dataset_resource = - return_values[0].scalar()(); - - // NOTE(mrry): We cannot use the core `LookupResource()` or - // `DeleteResource()` functions, because we have an - // `IteratorContext*` and not an `OpKernelContext*`, so we - // replicate the necessary functionality here. - auto type_index = MakeTypeIndex(); - if (type_index.hash_code() != dataset_resource.hash_code()) { - return errors::InvalidArgument( - "`f` must return a Dataset resource."); - } - TF_RETURN_IF_ERROR( - dataset()->captured_func_->resource_manager()->Lookup( - dataset_resource.container(), dataset_resource.name(), - &returned_dataset)); - core::ScopedUnref unref_dataset(returned_dataset); - - // Create an iterator for the dataset that was returned by - // `f`. This transfers ownership of the dataset to the - // iterator, so we can delete it from the resource manager. 
-        current_element_iterator_ = returned_dataset->MakeIterator(
-            strings::StrCat(prefix(), "[", element_index_++, "]"));
-        TF_RETURN_IF_ERROR(
-            dataset()
-                ->captured_func_->resource_manager()
-                ->Delete<DatasetBase>(dataset_resource.container(),
-                                      dataset_resource.name()));
+        TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+            ctx, args, element_index_++, dataset()->captured_func_.get(),
+            prefix(), &current_element_iterator_));
       } while (true);
     }
 
diff --git a/tensorflow/core/kernels/interleave_dataset_op.cc b/tensorflow/core/kernels/interleave_dataset_op.cc
index dce4f881017..7b148b74c98 100644
--- a/tensorflow/core/kernels/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/interleave_dataset_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/lib/random/random.h"
 
 #include "tensorflow/core/kernels/captured_function.h"
+#include "tensorflow/core/kernels/dataset_utils.h"
 
 namespace tensorflow {
 
@@ -168,8 +169,9 @@ class InterleaveDatasetOp : public OpKernel {
             TF_RETURN_IF_ERROR(
                 input_impl_->GetNext(ctx, &args, &end_of_input_));
             if (!end_of_input_) {
-              TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
-                  ctx, args, &current_elements_[cycle_index_]));
+              TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+                  ctx, args, cycle_index_, dataset()->captured_func_.get(),
+                  prefix(), &current_elements_[cycle_index_]));
               ++num_open_;
             }
           } else {
@@ -182,62 +184,6 @@ class InterleaveDatasetOp : public OpKernel {
       }
 
      private:
-      Status MakeIteratorFromInputElement(
-          IteratorContext* ctx, const std::vector<Tensor>& input_element,
-          std::unique_ptr<IteratorBase>* out_iterator)
-          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-        FunctionLibraryRuntime::Options opts;
-        opts.runner = ctx->runner();
-        opts.step_id = CapturedFunction::generate_step_id();
-        ScopedStepContainer step_container(
-            opts.step_id, [this, ctx](const string& name) {
-              dataset()
-                  ->captured_func_->resource_manager()
-                  ->Cleanup(name)
-                  .IgnoreError();
-            });
-        opts.step_container = &step_container;
-        std::vector<Tensor> return_values;
-        TF_RETURN_IF_ERROR(dataset()->captured_func_->Run(opts, input_element,
-                                                          &return_values));
-
-        if (!(return_values.size() == 1 &&
-              return_values[0].dtype() == DT_RESOURCE &&
-              TensorShapeUtils::IsScalar(return_values[0].shape()))) {
-          return errors::InvalidArgument(
-              "`f` must return a single scalar of dtype DT_RESOURCE.");
-        }
-
-        // Retrieve the dataset that was created in `f`.
-        DatasetBase* returned_dataset;
-        const ResourceHandle& dataset_resource =
-            return_values[0].scalar<ResourceHandle>()();
-
-        // NOTE(mrry): We cannot use the core `LookupResource()` or
-        // `DeleteResource()` functions, because we have an
-        // `IteratorContext*` and not an `OpKernelContext*`, so we
-        // replicate the necessary functionality here.
-        auto type_index = MakeTypeIndex<DatasetBase>();
-        if (type_index.hash_code() != dataset_resource.hash_code()) {
-          return errors::InvalidArgument("`f` must return a Dataset resource.");
-        }
-        TF_RETURN_IF_ERROR(
-            dataset()->captured_func_->resource_manager()->Lookup(
-                dataset_resource.container(), dataset_resource.name(),
-                &returned_dataset));
-        core::ScopedUnref unref_dataset(returned_dataset);
-
-        // Create an iterator for the dataset that was returned by
-        // `f`. This transfers ownership of the dataset to the
-        // iterator, so we can delete it from the resource manager.
-        *out_iterator = returned_dataset->MakeIterator(
-            strings::StrCat(prefix(), "[", cycle_index_, "]"));
-        TF_RETURN_IF_ERROR(
-            dataset()->captured_func_->resource_manager()->Delete<DatasetBase>(
-                dataset_resource.container(), dataset_resource.name()));
-        return Status::OK();
-      }
-
       mutex mu_;
       const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
       std::vector<std::unique_ptr<IteratorBase>> current_elements_
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 50700c8bc8a..00884d09814 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -98,11 +98,11 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
                               "Conv2DCustomBackpropInput: size must be 4-dim"));
 
       const int64* filter_sizes =
-          (const int64*) mkl_context.filter_shape.GetSizes();
+          (const int64*)mkl_context.filter_shape.GetSizes();
       const int64 filter_dims = mkl_context.filter_shape.GetDimension();
 
-      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(filter_sizes,
-                                  filter_dims, &filter_shape));
+      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+                                  filter_sizes, filter_dims, &filter_shape));
     } else {
       filter_shape = filter.shape();
     }
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index b50a6343ba9..7099aa13071 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -270,22 +270,22 @@ class MklConv2DOp : public OpKernel {
     MklShape mkl_filter_output_mkl_shape;
     mkl_filter_output_mkl_shape.SetMklTensor(true);
     mkl_filter_output_mkl_shape.SetMklLayout(mkl_context.prim_fwd,
-                                            dnnResourceFilter);
+                                             dnnResourceFilter);
     size_t filter_sizes[4] = {filter.dim_size(0), filter.dim_size(1),
-                             filter.dim_size(2), filter.dim_size(3)};
+                              filter.dim_size(2), filter.dim_size(3)};
     mkl_filter_output_mkl_shape.SetTfLayout(filter.dims(), filter_sizes,
-                                           mkl_context.filter_strides);
+                                            mkl_context.filter_strides);
 
     mkl_filter_output_mkl_shape.SetTfDimOrder(mkl_context.filter_dims,
-                                             data_format_);
+                                              data_format_);
     mkl_filter_output_tf_shape.AddDim(
-        dnnLayoutGetMemorySize_F32(
-            static_cast<dnnLayout_t>(
-                mkl_filter_output_mkl_shape.GetMklLayout())) /
-        sizeof(T));
+        dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+            mkl_filter_output_mkl_shape.GetMklLayout())) /
+        sizeof(T));
 
     AllocateOutputSetMklShape(context, 1, &mkl_context.output_filter,
-                              mkl_filter_output_tf_shape, mkl_filter_output_mkl_shape);
+                              mkl_filter_output_tf_shape,
+                              mkl_filter_output_mkl_shape);
 
     mkl_context.conv_res[dnnResourceDst] =
         static_cast<void*>(output->flat<T>().data());
@@ -406,8 +406,8 @@ class MklConv2DOp : public OpKernel {
       CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_filter, lt_filter,
                                        mkl_lt_internal_filter),
               E_SUCCESS);
-      mkl_buf_convert_filter = const_cast<void*>(static_cast<const void*>(
-          output_filter->flat<T>().data()));
+      mkl_buf_convert_filter = const_cast<void*>(
+          static_cast<const void*>(output_filter->flat<T>().data()));
       CHECK_EQ(
           dnnConversionExecute_F32(mkl_prim_convert_filter, mkl_buf_filter,
                                    mkl_buf_convert_filter),
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
index 03c3fb09a1d..5e985824750 100644
--- a/tensorflow/core/kernels/mkl_reshape_op.cc
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -128,6 +128,7 @@ class MklReshapeOp : public OpKernel {
       CopyTfTensorInToOutWithShape(context, 0, 0, shape);
     }
   }
+
  private:
   template <typename Tshape>
   Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index,
diff --git a/tensorflow/core/kernels/parse_tensor_op.cc b/tensorflow/core/kernels/parse_tensor_op.cc
index dd645262d2e..ab91a6ef677 100644
--- a/tensorflow/core/kernels/parse_tensor_op.cc
+++ b/tensorflow/core/kernels/parse_tensor_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
 // See docs in ../ops/parsing_ops.cc.
 
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -66,7 +67,6 @@ class ParseTensorOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("ParseTensor").Device(DEVICE_CPU), ParseTensorOp);
 
-
 template <typename T>
 class SerializeTensorOp : public OpKernel {
  public:
@@ -81,14 +81,14 @@ class SerializeTensorOp : public OpKernel {
       tensor.AsProtoTensorContent(&proto);
     }
     Tensor* proto_string = nullptr;
-    OP_REQUIRES_OK(
-        context, context->allocate_output(0, TensorShape({}), &proto_string));
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({}), &proto_string));
     CHECK(proto.SerializeToString(&proto_string->scalar<string>()()));
   }
 };
 
-#define REGISTER(T)                                                      \
-  REGISTER_KERNEL_BUILDER(                                               \
+#define REGISTER(T)                                                       \
+  REGISTER_KERNEL_BUILDER(                                                \
       Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       SerializeTensorOp<T>);
 TF_CALL_ALL_TYPES(REGISTER)
diff --git a/tensorflow/core/kernels/parse_tensor_test.cc b/tensorflow/core/kernels/parse_tensor_test.cc
index f6f60fee71c..4a5fc07935c 100644
--- a/tensorflow/core/kernels/parse_tensor_test.cc
+++ b/tensorflow/core/kernels/parse_tensor_test.cc
@@ -14,8 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include
-#include
 #include
+#include
 
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
@@ -33,27 +33,23 @@ namespace {
 
 class SerializeTensorOpTest : public OpsTestBase {
  protected:
  template <typename T>
-  void MakeOp(const TensorShape& input_shape,
-              std::function<T(int)> functor) {
-    TF_ASSERT_OK(
-        NodeDefBuilder("myop", "SerializeTensor")
-            .Input(FakeInput(DataTypeToEnum<T>::value))
-            .Finalize(node_def()));
+  void MakeOp(const TensorShape& input_shape, std::function<T(int)> functor) {
+    TF_ASSERT_OK(NodeDefBuilder("myop", "SerializeTensor")
+                     .Input(FakeInput(DataTypeToEnum<T>::value))
+                     .Finalize(node_def()));
     TF_ASSERT_OK(InitOp());
    AddInput<T>(input_shape, functor);
   }
   void ParseSerializedWithNodeDef(const NodeDef& parse_node_def,
-                                  Tensor* serialized,
-                                  Tensor* parse_output) {
+                                  Tensor* serialized, Tensor* parse_output) {
     std::unique_ptr<Device> device(
         DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
     gtl::InlinedVector<TensorValue, 4> inputs;
     inputs.push_back({nullptr, serialized});
     Status status;
-    std::unique_ptr<OpKernel> op(
-        CreateOpKernel(DEVICE_CPU, device.get(),
-                       cpu_allocator(), parse_node_def,
-                       TF_GRAPH_DEF_VERSION, &status));
+    std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
+                                                cpu_allocator(), parse_node_def,
+                                                TF_GRAPH_DEF_VERSION, &status));
     TF_EXPECT_OK(status);
     OpKernelContext::Params params;
     params.device = device.get();
@@ -80,8 +76,8 @@ class SerializeTensorOpTest : public OpsTestBase {
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_half) {
   MakeOp<Eigen::half>(TensorShape({10}), [](int x) -> Eigen::half {
-        return static_cast<Eigen::half>(x / 10.);
-      });
+    return static_cast<Eigen::half>(x / 10.);
+  });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
   ParseSerializedOutput<Eigen::half>(GetOutput(0), &parse_output);
@@ -89,9 +85,8 @@ TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_half) {
 }
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_float) {
-  MakeOp<float>(TensorShape({1, 10}), [](int x) -> float {
-    return static_cast<float>(x / 10.);
-  });
+  MakeOp<float>(TensorShape({1, 10}),
+                [](int x) -> float { return static_cast<float>(x / 10.); });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
   ParseSerializedOutput<float>(GetOutput(0), &parse_output);
@@ -99,9 +94,8 @@ TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_float) {
 }
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_double) {
-  MakeOp<double>(TensorShape({5, 5}), [](int x) -> double {
-    return static_cast<double>(x / 10.);
-  });
+  MakeOp<double>(TensorShape({5, 5}),
+                 [](int x) -> double { return static_cast<double>(x / 10.); });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
   ParseSerializedOutput<double>(GetOutput(0), &parse_output);
@@ -109,9 +103,8 @@ TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_double) {
 }
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int64) {
-  MakeOp<int64>(TensorShape({2, 3, 4}), [](int x) -> int64 {
-    return static_cast<int64>(x - 10);
-  });
+  MakeOp<int64>(TensorShape({2, 3, 4}),
+                [](int x) -> int64 { return static_cast<int64>(x - 10); });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
   ParseSerializedOutput<int64>(GetOutput(0), &parse_output);
@@ -119,9 +112,8 @@ TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int64) {
 }
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int32) {
-  MakeOp<int32>(TensorShape({4, 2}), [](int x) -> int32 {
-    return static_cast<int32>(x + 7);
-  });
+  MakeOp<int32>(TensorShape({4, 2}),
+                [](int x) -> int32 { return static_cast<int32>(x + 7); });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
   ParseSerializedOutput<int32>(GetOutput(0), &parse_output);
@@ -129,9 +121,8 @@ TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int32) {
 }
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int16) {
-  MakeOp<int16>(TensorShape({8}), [](int x) -> int16 {
-    return static_cast<int16>(x + 18);
-  });
+  MakeOp<int16>(TensorShape({8}),
+                [](int x) -> int16 { return static_cast<int16>(x + 18); });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
   ParseSerializedOutput<int16>(GetOutput(0), &parse_output);
@@ -139,9 +130,8 @@ TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int16) {
 }
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int8) {
-  MakeOp<int8>(TensorShape({2}), [](int x) -> int8 {
-    return static_cast<int8>(x + 8);
-  });
+  MakeOp<int8>(TensorShape({2}),
+               [](int x) -> int8 { return static_cast<int8>(x + 8); });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
   ParseSerializedOutput<int8>(GetOutput(0), &parse_output);
@@ -149,9 +139,8 @@ TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int8) {
 }
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint16) {
-  MakeOp<uint16>(TensorShape({1, 3}), [](int x) -> uint16 {
-    return static_cast<uint16>(x + 2);
-  });
+  MakeOp<uint16>(TensorShape({1, 3}),
+                 [](int x) -> uint16 { return static_cast<uint16>(x + 2); });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
   ParseSerializedOutput<uint16>(GetOutput(0), &parse_output);
@@ -159,9 +148,8 @@ TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint16) {
 }
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint8) {
-  MakeOp<uint8>(TensorShape({2, 1, 1}), [](int x) -> uint8 {
-    return static_cast<uint8>(x + 1);
-  });
+  MakeOp<uint8>(TensorShape({2, 1, 1}),
+                [](int x) -> uint8 { return static_cast<uint8>(x + 1); });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
   ParseSerializedOutput<uint8>(GetOutput(0), &parse_output);
@@ -170,9 +158,8 @@ TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint8) {
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex64) {
   MakeOp<complex64>(TensorShape({}), [](int x) -> complex64 {
-    return complex64{ static_cast<float>(x / 8.),
-                      static_cast<float>(x / 2.) };
-  });
+    return complex64{static_cast<float>(x / 8.), static_cast<float>(x / 2.)};
+  });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
   ParseSerializedOutput<complex64>(GetOutput(0), &parse_output);
@@ -181,8 +168,8 @@ TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex64) {
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex128) {
   MakeOp<complex128>(TensorShape({3}), [](int x) -> complex128 {
-    return complex128{ x / 3., x / 2. };
-  });
+    return complex128{x / 3., x / 2.};
+  });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
   ParseSerializedOutput<complex128>(GetOutput(0), &parse_output);
@@ -190,9 +177,8 @@ TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex128) {
 }
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_bool) {
-  MakeOp<bool>(TensorShape({1}), [](int x) -> bool {
-    return static_cast<bool>(x % 2);
-  });
+  MakeOp<bool>(TensorShape({1}),
+               [](int x) -> bool { return static_cast<bool>(x % 2); });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
   ParseSerializedOutput<bool>(GetOutput(0), &parse_output);
@@ -200,13 +186,12 @@ TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_bool) {
 }
 
 TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_string) {
-  MakeOp<std::string>(TensorShape({10}), [](int x) -> std::string {
-    return std::to_string(x / 10.);
-  });
+  MakeOp<string>(TensorShape({10}),
+                 [](int x) -> string { return std::to_string(x / 10.); });
   TF_ASSERT_OK(RunOpKernel());
   Tensor parse_output;
-  ParseSerializedOutput<std::string>(GetOutput(0), &parse_output);
-  test::ExpectTensorEqual<std::string>(parse_output, GetInput(0));
+  ParseSerializedOutput<string>(GetOutput(0), &parse_output);
+  test::ExpectTensorEqual<string>(parse_output, GetInput(0));
 }
 
 }  // namespace
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index 8f7eff113cd..5624d5cd1b1 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -35,7 +35,6 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/util.h"
 
-
 #if GOOGLE_CUDA
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/kernels/cuda_solvers.h"
@@ -249,10 +248,11 @@ class SegmentSumGPUOp : public AsyncOpKernel {
     auto stream = context->op_device_context()->stream();
     OP_REQUIRES_ASYNC(
-        context, stream
-                     ->ThenMemcpy(output_rows_host.mutable_data(),
-                                  output_rows_device, sizeof(Index))
-                     .ok(),
+        context,
+        stream
+            ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
+                         sizeof(Index))
+            .ok(),
         errors::Internal(
             "SegmentSumGPUOp: failed to copy output_rows from device"),
         done);
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
index 26fcafee34a..159fada621b 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
@@ -186,10 +186,10 @@ void SegmentSumFunctor<T, Index>::operator()(
       input_inner_dim_size * input_outer_dim_num_stripe;
 
   config = GetCudaLaunchConfig(total_stripe_count, d);
-  SortedSegmentSumCustomKernel<T, Index, OuterDimTileSize><<<
-      config.block_count, config.thread_per_block, 0, d.stream()>>>(
-      input_outer_dim_size, input_inner_dim_size, output_rows,
-      segment_ids.data(), data, output.data(), total_stripe_count);
+  SortedSegmentSumCustomKernel<T, Index, OuterDimTileSize>
+      <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+          input_outer_dim_size, input_inner_dim_size, output_rows,
+          segment_ids.data(), data, output.data(), total_stripe_count);
 };
 
 // UnsortedSegmentSumFunctor implementation for GPUDevice.
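Note on the refactoring above: `FlatMapDataset`, `InterleaveDataset`, and the new `SloppyInterleaveDataset` below all construct per-element iterators through the shared `dataset::MakeIteratorFromInputElement()` helper introduced in `dataset_utils.{h,cc}`. A minimal sketch of the calling pattern, under the assumption of a hypothetical `DatasetIterator` subclass; the members `element_index_` and `current_impl_` are illustrative, while `dataset()`, `prefix()`, and `captured_func_` mirror the kernels in this diff:

```cpp
// Sketch only: how a dataset kernel's iterator would use the shared helper.
Status AdvanceToNextElement(IteratorContext* ctx,
                            const std::vector<Tensor>& input_element) {
  std::unique_ptr<IteratorBase> iter;
  // The helper runs the captured function `f` on `input_element`, looks up
  // the returned Dataset resource, and yields an iterator over it.
  TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
      ctx, input_element, element_index_++, dataset()->captured_func_.get(),
      prefix(), &iter));
  current_impl_ = std::move(iter);
  return Status::OK();
}
```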
diff --git a/tensorflow/core/kernels/sloppy_interleave_dataset_op.cc b/tensorflow/core/kernels/sloppy_interleave_dataset_op.cc
new file mode 100644
index 00000000000..d95f51f0f21
--- /dev/null
+++ b/tensorflow/core/kernels/sloppy_interleave_dataset_op.cc
@@ -0,0 +1,370 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/dataset.h"
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/dataset_utils.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/random/random.h"
+
+#include "tensorflow/core/kernels/captured_function.h"
+
+namespace tensorflow {
+
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  explicit SloppyInterleaveDatasetOp(OpKernelConstruction* ctx)
+      : UnaryDatasetOpKernel(ctx),
+        graph_def_version_(ctx->graph_def_version()) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  }
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override {
+    OpInputList inputs;
+    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
+    std::vector<Tensor> other_arguments;
+    other_arguments.reserve(inputs.size());
+    for (const Tensor& t : inputs) {
+      other_arguments.push_back(t);
+    }
+
+    int64 cycle_length;
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument<int64>(ctx, "cycle_length", &cycle_length));
+    OP_REQUIRES(ctx, cycle_length > 0,
+                errors::InvalidArgument("`cycle_length` must be > 0"));
+
+    int64 block_length;
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument<int64>(ctx, "block_length", &block_length));
+    OP_REQUIRES(ctx, block_length > 0,
+                errors::InvalidArgument("`block_length` must be > 0"));
+
+    std::unique_ptr<CapturedFunction> captured_func;
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_, graph_def_version_,
+                                                 std::move(other_arguments),
+                                                 &captured_func));
+
+    *output = new Dataset(input, std::move(captured_func), cycle_length,
+                          block_length, output_types_, output_shapes_);
+  }
+
+ private:
+  class Dataset : public DatasetBase {
+   public:
+    Dataset(const DatasetBase* input,
+            std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
+            int64 block_length, const DataTypeVector& output_types,
+            const std::vector<PartialTensorShape>& output_shapes)
+        : input_(input),
+          captured_func_(std::move(captured_func)),
+          cycle_length_(cycle_length),
+          block_length_(block_length),
+          output_types_(output_types),
+          output_shapes_(output_shapes) {
+      input_->Ref();
+    }
+
+    ~Dataset() override { input_->Unref(); }
+
+    std::unique_ptr<IteratorBase> MakeIterator(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(
+          new Iterator({this, strings::StrCat(prefix, "::SloppyInterleave")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      return output_types_;
+    }
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return output_shapes_;
+    }
+
+    string DebugString() override {
+      return "SloppyInterleaveDatasetOp::Dataset";
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params),
+            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
+            output_elements_(params.dataset->cycle_length_) {}
+
+      ~Iterator() override {
+        mutex_lock l(mu_);
+        cancelled_ = true;
+        // Notify all workers in case they are blocked.
+        for (int64 i = 0; i < dataset()->cycle_length_; ++i) {
+          output_elements_[i].cond_var.notify_all();
+        }
+      }
+
+      // This implementation matches the deterministic interleave order unless
+      // it would have to block waiting for an element, in which case it skips
+      // ahead to the next value that is already available.
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
+        // Search for available items, blocking if necessary.
+        while (!cancelled_) {
+          for (size_t i = 0; i < dataset()->cycle_length_; ++i) {
+            size_t index = (next_index_ + i) % dataset()->cycle_length_;
+            if (output_elements_[index].is_produced) {
+              next_index_ = index;
+              if (i == 0) {
+                block_count_++;
+                if (block_count_ == dataset()->block_length_) {
+                  next_index_ = (index + 1) % dataset()->cycle_length_;
+                  block_count_ = 0;
+                }
+              } else {
+                block_count_ = 0;
+              }
+              // If we encounter an end-of-sequence marker, advance to the
+              // next iterator.
+              if (output_elements_[index].end_of_sequence) {
+                output_elements_[index].is_produced = false;
+                output_elements_[index].cond_var.notify_one();
+                next_index_ = (index + 1) % dataset()->cycle_length_;
+                block_count_ = 0;
+                i = -1;  // Restart the inner loop.
+                continue;
+              }
+              *end_of_sequence = false;
+              if (output_elements_[index].output_status.ok()) {
+                output_elements_[index].output_value.swap(*out_tensors);
+              }
+              output_elements_[index].is_produced = false;
+              output_elements_[index].cond_var.notify_one();
+              return output_elements_[index].output_status;
+            }
+          }
+
+          if (num_active_threads_ == 0) {
+            // No potential for future values.
+            //
+            // Note: this condition check must occur after checking the output
+            // buffer, as it's possible for there to be values in the output
+            // buffer, even if the number of live threads is zero.
+            *end_of_sequence = true;
+            return Status::OK();
+          }
+          // No values available; wait until woken up.
+          cond_var_.wait(l);
+        }
+        return errors::Cancelled(
+            "SloppyInterleaveDatasetOp::Dataset::Iterator::GetNext");
+      }
+
+     private:
+      // Internal structure to manage thread coordination. All values are
+      // guarded by the enclosing Iterator's mu_.
+      struct OutputBufferElement {
+        // The producer must set `is_produced` to `true` after
+        // `output_status` or `output_value` has been written.
+        bool is_produced = false;
+        // The producer sets `output_status` if either getting the input
+        // element or applying the function to it fails.
+        Status output_status;
+        // Reached end of sequence for the underlying iterator.
+        bool end_of_sequence = false;
+        // The output data element.
+        std::vector<Tensor> output_value;
+        // The producer thread waits on this condition variable after having
+        // produced an element. The reader thread notifies this condition
+        // variable after reading the value.
+        condition_variable cond_var;
+      };
+
+      Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        if (worker_threads_.empty()) {
+          for (int64 i = 0; i < dataset()->cycle_length_; ++i) {
+            // Serialize the creation of the workers and their corresponding
+            // input elements to ensure we match the standard interleave when
+            // the underlying iterators induce no delay.
+            std::vector<Tensor> args;
+            TF_RETURN_IF_ERROR(
+                input_impl_->GetNext(ctx, &args, &end_of_input_));
+            if (end_of_input_) {
+              LOG(WARNING) << "Input iterator exhausted after " << i
+                           << " elements; cannot start all "
+                           << dataset()->cycle_length_ << " worker threads.";
+              return Status::OK();
+            }
+            std::unique_ptr<IteratorBase> itr;
+            TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+                ctx, args, i, dataset()->captured_func_.get(), prefix(), &itr));
+            worker_threads_.emplace_back(
+                std::unique_ptr<Thread>(ctx->env()->StartThread(
+                    {}, "worker_thread",
+                    std::bind(&Iterator::WorkerThread, this,
+                              new IteratorContext(*ctx), i, itr.release()))));
+            num_active_threads_ = i + 1;
+          }
+        }
+        return Status::OK();
+      }
+
+      void BlockAndUpdateOutputBuffer(mutex_lock* l, const int64 thread_index,
+                                      const Status& status,
+                                      bool end_of_sequence,
+                                      std::vector<Tensor>* out_tensors)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        // We have produced an element; push it into the output buffer
+        // when space is available.
+        while (!cancelled_ && output_elements_[thread_index].is_produced) {
+          output_elements_[thread_index].cond_var.wait(*l);
+        }
+        if (cancelled_) {
+          return;
+        }
+        output_elements_[thread_index].is_produced = true;
+        output_elements_[thread_index].output_status = status;
+        output_elements_[thread_index].end_of_sequence = end_of_sequence;
+        if (status.ok()) {
+          output_elements_[thread_index].output_value.swap(*out_tensors);
+        } else {
+          output_elements_[thread_index].output_value.clear();
+        }
+        cond_var_.notify_one();
+      }
+
+      // Races to produce elements into the output queue buffers.
+      void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index,
+                        IteratorBase* out_iterator_ptr) {
+        // std::function arguments must be copy-constructible, so we pass raw
+        // pointers, and then immediately wrap them to ensure correct
+        // ownership.
+        std::unique_ptr<IteratorContext> ctx(ctx_ptr);
+        std::unique_ptr<IteratorBase> out_iterator(out_iterator_ptr);
+        auto cleanup = gtl::MakeCleanup([this, thread_index] {
+          mutex_lock l(mu_);
+          num_active_threads_--;
+          cond_var_.notify_all();
+        });
+        while (true) {
+          // Attempt to produce an element.
+          bool end_of_out_itr_input = false;
+          std::vector<Tensor> out_tensors;
+          Status element_status = out_iterator->GetNext(ctx.get(), &out_tensors,
+                                                        &end_of_out_itr_input);
+          // Handle output.
+          {
+            mutex_lock l(mu_);
+            BlockAndUpdateOutputBuffer(&l, thread_index, element_status,
+                                       end_of_out_itr_input, &out_tensors);
+            if (end_of_out_itr_input) {
+              // We have exhausted our current iterator; get a new iterator;
+              // loop to handle errors.
+              while (!cancelled_) {
+                if (end_of_input_) {
+                  // No more iterator inputs; we're done!
+                  return;
+                }
+                std::vector<Tensor> args;
+                // BlockAndUpdateOutputBuffer() sequences calls to
+                // input_impl_->GetNext() when the out_iterator doesn't cause
+                // sloppy (out-of-order) execution.
+                Status input_status =
+                    input_impl_->GetNext(ctx.get(), &args, &end_of_input_);
+                if (end_of_input_) {
+                  // No more elements to produce; stop the worker thread.
+                  return;
+                }
+                if (input_status.ok()) {
+                  input_status = dataset::MakeIteratorFromInputElement(
+                      ctx.get(), args, thread_index,
+                      dataset()->captured_func_.get(), prefix(), &out_iterator);
+                }
+                if (input_status.ok()) {
+                  // We successfully have a new out_iterator; restart the
+                  // outer loop to produce an element.
+                  break;
+                }
+
+                // We encountered an error; push the error to the output
+                // buffer.
+                BlockAndUpdateOutputBuffer(&l, thread_index, input_status,
+                                           /* end_of_sequence = */ false,
+                                           &out_tensors);
+              }
+            }
+
+            // Check if we should exit.
+            if (cancelled_) {
+              return;
+            }
+          }
+        }
+      }
+
+      // Mutex & condition variable to guard mutable iterator internals and
+      // coordinate among worker threads and client thread[s].
+      mutex mu_;
+      condition_variable cond_var_;
+      // The iterator producing elements which are converted to datasets by
+      // the dataset()->captured_func_ and then interleaved together.
+      const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+      // Whether the input_impl_ can produce future elements.
+      bool end_of_input_ GUARDED_BY(mu_) = false;
+      // The buffer of elements to be produced. Each worker thread operates
+      // on a single OutputBufferElement.
+      std::vector<OutputBufferElement> output_elements_ GUARDED_BY(mu_);
+      // The index into output_elements_ for the next element to produce.
+      size_t next_index_ GUARDED_BY(mu_) = 0;
+      // The number of items produced so far within the block.
+      size_t block_count_ GUARDED_BY(mu_) = 0;
+      // Number of active worker threads.
+      size_t num_active_threads_ GUARDED_BY(mu_) = 0;
+      // Flag to instruct the worker threads to exit.
+      bool cancelled_ GUARDED_BY(mu_) = false;
+      // Pointers to the worker threads. This must be last to ensure the
+      // threads have exited before any other members are deallocated.
+      // TODO(b/65178177): Avoid allocating additional threads.
+      std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
+    };
+
+    const DatasetBase* const input_;
+    const std::unique_ptr<CapturedFunction> captured_func_;
+    const int64 cycle_length_;
+    const int64 block_length_;
+    const DataTypeVector output_types_;
+    const std::vector<PartialTensorShape> output_shapes_;
+  };
+
+  const int graph_def_version_;
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
+  const NameAttrList* func_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("SloppyInterleaveDataset").Device(DEVICE_CPU),
+                        SloppyInterleaveDatasetOp);
+
+}  // namespace
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index d0eca0f1e7f..cfa707de715 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -40,12 +40,7 @@ class CreateSummaryFileWriterOp : public OpKernel {
     SummaryWriterInterface* s;
     OP_REQUIRES_OK(ctx, CreateSummaryWriter(max_queue, flush_millis, logdir,
                                             filename_suffix, ctx->env(), &s));
-    Status status = CreateResource(ctx, HandleFromInput(ctx, 0), s);
-    if (!status.ok()) {
-      s->Unref();
-      ctx->SetStatus(status);
-      return;
-    }
+    OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s));
   }
 };
 REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
diff --git a/tensorflow/core/lib/io/buffered_inputstream.cc b/tensorflow/core/lib/io/buffered_inputstream.cc
index 6f72da47131..b247e9c5756 100644
--- a/tensorflow/core/lib/io/buffered_inputstream.cc
+++ b/tensorflow/core/lib/io/buffered_inputstream.cc
@@ -41,9 +41,18 @@ BufferedInputStream::~BufferedInputStream() {
 }
 
 Status BufferedInputStream::FillBuffer() {
+  if (!file_status_.ok()) {
+    pos_ = 0;
+    limit_ = 0;
+    return file_status_;
+  }
   Status s = input_stream_->ReadNBytes(size_, &buf_);
   pos_ = 0;
   limit_ = buf_.size();
+  if (buf_.empty()) {
+    DCHECK(!s.ok());
+    file_status_ = s;
+  }
   return s;
 }
 
@@ -82,6 +91,9 @@ Status BufferedInputStream::ReadNBytes(int64 bytes_to_read, string* result) {
                                    bytes_to_read);
   }
   result->clear();
+  if (!file_status_.ok() && bytes_to_read > 0) {
+    return file_status_;
+  }
   result->reserve(bytes_to_read);
 
   Status s;
@@ -91,6 +103,8 @@ Status BufferedInputStream::ReadNBytes(int64 bytes_to_read, string* result) {
       s = FillBuffer();
       // If we didn't read any bytes, we're at the end of the file; break out.
       if (limit_ == 0) {
+        DCHECK(!s.ok());
+        file_status_ = s;
         break;
       }
     }
@@ -124,6 +138,9 @@ Status BufferedInputStream::SkipNBytes(int64 bytes_to_skip) {
     Status s = input_stream_->SkipNBytes(bytes_to_skip - (limit_ - pos_));
     pos_ = 0;
     limit_ = 0;
+    if (errors::IsOutOfRange(s)) {
+      file_status_ = s;
+    }
     return s;
   }
   return Status::OK();
@@ -163,6 +180,7 @@ Status BufferedInputStream::ReadAll(string* result) {
   }
 
   if (errors::IsOutOfRange(status)) {
+    file_status_ = status;
     return Status::OK();
   }
   return status;
@@ -172,6 +190,7 @@ Status BufferedInputStream::Reset() {
   TF_RETURN_IF_ERROR(input_stream_->Reset());
   pos_ = 0;
   limit_ = 0;
+  file_status_ = Status::OK();
   return Status::OK();
 }
 
diff --git a/tensorflow/core/lib/io/buffered_inputstream.h b/tensorflow/core/lib/io/buffered_inputstream.h
index b37766005a9..2b824f35f80 100644
--- a/tensorflow/core/lib/io/buffered_inputstream.h
+++ b/tensorflow/core/lib/io/buffered_inputstream.h
@@ -94,6 +94,9 @@ class BufferedInputStream : public InputStreamInterface {
   size_t pos_ = 0;    // current position in buf_.
   size_t limit_ = 0;  // just past the end of valid data in buf_.
   bool owns_input_stream_ = false;
+  // When an EoF is reached, file_status_ caches the error status so that
+  // subsequent reads can skip unnecessary buffer allocations.
+  Status file_status_ = Status::OK();
 
   TF_DISALLOW_COPY_AND_ASSIGN(BufferedInputStream);
 };
diff --git a/tensorflow/core/lib/io/buffered_inputstream_test.cc b/tensorflow/core/lib/io/buffered_inputstream_test.cc
index 7265101e1be..49b2b1a861a 100644
--- a/tensorflow/core/lib/io/buffered_inputstream_test.cc
+++ b/tensorflow/core/lib/io/buffered_inputstream_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/lib/io/random_inputstream.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
 
 namespace tensorflow {
 namespace io {
@@ -362,6 +363,45 @@ TEST(BufferedInputStream, ReadAll_Text) {
   }
 }
 
+void BM_BufferedReaderSmallReads(const int iters, const int buff_size,
+                                 const int file_size) {
+  testing::StopTiming();
+  Env* env = Env::Default();
+  string fname = testing::TmpDir() + "/buffered_inputstream_test";
+
+  const string file_elem = "0123456789";
+  std::unique_ptr<WritableFile> write_file;
+  TF_ASSERT_OK(env->NewWritableFile(fname, &write_file));
+  for (int i = 0; i < file_size; ++i) {
+    TF_ASSERT_OK(write_file->Append(file_elem));
+  }
+  TF_ASSERT_OK(write_file->Close());
+
+  std::unique_ptr<RandomAccessFile> file;
+  TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));
+
+  string result;
+  testing::StartTiming();
+
+  for (int itr = 0; itr < iters; ++itr) {
+    BufferedInputStream in(file.get(), buff_size);
+    for (int64 i = 0; i < 10 * file_size; ++i) {
+      TF_ASSERT_OK(in.ReadNBytes(1, &result))
+          << "i: " << i << " itr: " << itr << " buff_size: " << buff_size
+          << " file size: " << file_size;
+    }
+  }
+}
+BENCHMARK(BM_BufferedReaderSmallReads)
+    ->ArgPair(1, 5)
+    ->ArgPair(1, 1024)
+    ->ArgPair(10, 5)
+    ->ArgPair(10, 1024)
+    ->ArgPair(1024, 1024)
+    ->ArgPair(1024 * 1024, 1024)
+    ->ArgPair(1024 * 1024, 1024 * 1024)
+    ->ArgPair(256 * 1024 * 1024, 1024);
+
 }  // anonymous namespace
 }  // namespace io
 }  // namespace tensorflow
diff --git a/tensorflow/core/lib/io/zlib_inputstream.h b/tensorflow/core/lib/io/zlib_inputstream.h
index a8a4e7c83cc..8faa7dcb8f4 100644
--- a/tensorflow/core/lib/io/zlib_inputstream.h
+++ b/tensorflow/core/lib/io/zlib_inputstream.h
@@ -37,7 +37,7 @@ namespace io {
 // by multiple threads
 class ZlibInputStream : public InputStreamInterface {
  public:
-  // Create a ZlibInputBuffer for `input_stream` with a buffer of size
+  // Create a ZlibInputStream for `input_stream` with a buffer of size
   // `input_buffer_bytes` bytes for reading contents from `input_stream` and
   // another buffer with size `output_buffer_bytes` for caching decompressed
   // contents. Does *not* take ownership of "input_stream".
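For context on the `BufferedInputStream` change above: the first end-of-file status is now cached in `file_status_`, so the tight 1-byte-read loop exercised by `BM_BufferedReaderSmallReads` stops re-filling the buffer once EoF is hit. A minimal sketch of a caller that benefits, using only APIs visible in this diff (`fname` is a placeholder path):

```cpp
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

// Reads a file one byte at a time. After the first OUT_OF_RANGE status,
// every further ReadNBytes() returns the cached file_status_ immediately
// instead of asking the underlying stream to refill an empty buffer.
Status CountBytes(const string& fname, int64* num_bytes) {
  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(fname, &file));
  io::BufferedInputStream in(file.get(), 1024 /* buffer_bytes */);
  string chunk;
  *num_bytes = 0;
  Status s;
  while ((s = in.ReadNBytes(1, &chunk)).ok()) {
    ++*num_bytes;
  }
  // EoF is the expected terminal condition; any other error propagates.
  return errors::IsOutOfRange(s) ? Status::OK() : s;
}

}  // namespace tensorflow
```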
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 651f22c6eae..62c86c77145 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -5488,24 +5488,28 @@ REGISTER_OP("BatchMatrixDiag")
     .Input("diagonal: T")
     .Output("output: T")
     .Attr("T: type")
-    .Deprecated(14, "Use MatrixDiag");
+    .Deprecated(14, "Use MatrixDiag")
+    .SetShapeFn(shape_inference::UnknownShape);
 
 REGISTER_OP("BatchMatrixSetDiag")
     .Input("input: T")
     .Input("diagonal: T")
     .Output("output: T")
     .Attr("T: type")
-    .Deprecated(14, "Use MatrixSetDiag");
+    .Deprecated(14, "Use MatrixSetDiag")
+    .SetShapeFn(shape_inference::UnknownShape);
 
 REGISTER_OP("BatchMatrixDiagPart")
     .Input("input: T")
     .Output("diagonal: T")
     .Attr("T: type")
-    .Deprecated(14, "Use MatrixDiagPart");
+    .Deprecated(14, "Use MatrixDiagPart")
+    .SetShapeFn(shape_inference::UnknownShape);
 
 REGISTER_OP("BatchMatrixBandPart")
     .Input("input: T")
     .Input("num_lower: int64")
     .Input("num_upper: int64")
     .Output("band: T")
     .Attr("T: type")
-    .Deprecated(14, "Use MatrixBandPart");
+    .Deprecated(14, "Use MatrixBandPart")
+    .SetShapeFn(shape_inference::UnknownShape);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 22d4a0056f8..a8338620d69 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -24390,6 +24390,21 @@ op {
     type: "type"
   }
 }
+op {
+  name: "SerializeTensor"
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "serialized"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+}
 op {
   name: "SetSize"
   input_arg {
@@ -24844,6 +24859,51 @@ op {
     }
   }
 }
+op {
+  name: "SloppyInterleaveDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "cycle_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "block_length"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
 op {
   name: "Softmax"
   input_arg {
@@ -28817,6 +28877,40 @@ op {
     }
   }
 }
+op {
+  name: "Sub"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
 op {
   name: "Substr"
   input_arg {
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 37d9a737e29..7cc8dccb95c 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -233,6 +233,33 @@ f: A function mapping elements of `input_dataset`, concatenated with
   `output_types` and `output_shapes`.
)doc"); +REGISTER_OP("SloppyInterleaveDataset") + .Input("input_dataset: resource") + .Input("other_arguments: Targuments") + .Input("cycle_length: int64") + .Input("block_length: int64") + .Output("handle: resource") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that applies `f` to the outputs of `input_dataset`. + +The resulting dataset is similar to the `InterleaveDataset`, with the exception +that if retrieving the next value from a dataset would cause the requester to +block, it will skip that input dataset. This dataset is especially useful +when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it +allows the training step to proceed so long as some data is available. + +!! WARNING !! This dataset is not deterministic! + +f: A function mapping elements of `input_dataset`, concatenated with + `other_arguments`, to a Dataset resource that contains elements matching + `output_types` and `output_shapes`. +)doc"); + REGISTER_OP("GroupByWindowDataset") .Input("input_dataset: resource") .Input("key_func_other_arguments: Tkey_func_other_arguments") diff --git a/tensorflow/core/ops/debug_ops.cc b/tensorflow/core/ops/debug_ops.cc index bd7f7c2c018..5aebdca1ea5 100644 --- a/tensorflow/core/ops/debug_ops.cc +++ b/tensorflow/core/ops/debug_ops.cc @@ -32,6 +32,7 @@ REGISTER_OP("Copy") .Attr("tensor_name: string = ''") .Attr("debug_ops_spec: list(string) = []") .SetAllowsUninitializedInput() + .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Copy Op. @@ -61,6 +62,7 @@ REGISTER_OP("CopyHost") .Attr("tensor_name: string = ''") .Attr("debug_ops_spec: list(string) = []") .SetAllowsUninitializedInput() + .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Copy Host Op. @@ -118,6 +120,7 @@ REGISTER_OP("DebugNanCount") .Attr("debug_urls: list(string) = []") .Attr("gated_grpc: bool = false") .SetAllowsUninitializedInput() + .SetShapeFn(shape_inference::ScalarShape) .Doc(R"doc( Debug NaN Value Counter Op @@ -148,6 +151,8 @@ REGISTER_OP("DebugNumericSummary") .Attr("mute_if_healthy: bool = false") .Attr("gated_grpc: bool = false") .SetAllowsUninitializedInput() + // Note: this could return a more specific shape if needed in future. + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Debug Numeric Summary Op. diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc index 5b75bda1f1b..48b23623422 100644 --- a/tensorflow/core/ops/linalg_ops.cc +++ b/tensorflow/core/ops/linalg_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -557,34 +558,39 @@ REGISTER_OP("BatchSelfAdjointEig") .Input("input: T") .Output("output: T") .Attr("T: {double, float}") - .Deprecated(11, "Use SelfAdjointEigV2 instead."); + .Deprecated(11, "Use SelfAdjointEigV2 instead.") + .SetShapeFn(shape_inference::UnknownShape); // Can all be deleted after 9mar2017. 
REGISTER_OP("BatchMatrixDeterminant") .Input("input: T") .Output("output: T") .Attr("T: {float, double, complex64, complex128}") - .Deprecated(13, "Use MatrixDeterminant instead."); + .Deprecated(13, "Use MatrixDeterminant instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchMatrixInverse") .Input("input: T") .Output("output: T") .Attr("adjoint: bool = False") .Attr("T: {double, float}") - .Deprecated(13, "Use MatrixInverse instead."); + .Deprecated(13, "Use MatrixInverse instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchCholesky") .Input("input: T") .Output("output: T") .Attr("T: {double, float}") - .Deprecated(13, "Use Cholesky instead."); + .Deprecated(13, "Use Cholesky instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchCholeskyGrad") .Input("l: T") .Input("grad: T") .Output("output: T") .Attr("T: {float, double}") - .Deprecated(13, "Use CholeskyGrad instead."); + .Deprecated(13, "Use CholeskyGrad instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchSelfAdjointEigV2") .Input("input: T") @@ -592,7 +598,8 @@ REGISTER_OP("BatchSelfAdjointEigV2") .Output("v: T") .Attr("compute_v: bool = True") .Attr("T: {double, float}") - .Deprecated(13, "Use SelfAdjointEigV2 instead."); + .Deprecated(13, "Use SelfAdjointEigV2 instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchMatrixSolve") .Input("matrix: T") @@ -600,7 +607,8 @@ REGISTER_OP("BatchMatrixSolve") .Output("output: T") .Attr("adjoint: bool = False") .Attr("T: {double, float}") - .Deprecated(13, "Use MatrixSolve instead."); + .Deprecated(13, "Use MatrixSolve instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchMatrixTriangularSolve") .Input("matrix: T") @@ -609,7 +617,8 @@ REGISTER_OP("BatchMatrixTriangularSolve") .Attr("lower: bool = True") .Attr("adjoint: bool = False") .Attr("T: {double, float}") - .Deprecated(13, "Use MatrixTriangularSolve instead."); + .Deprecated(13, "Use MatrixTriangularSolve instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchMatrixSolveLs") .Input("matrix: T") @@ -618,7 +627,8 @@ REGISTER_OP("BatchMatrixSolveLs") .Output("output: T") .Attr("T: {double, float}") .Attr("fast: bool = True") - .Deprecated(13, "Use MatrixSolveLs instead."); + .Deprecated(13, "Use MatrixSolveLs instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchSvd") .Input("input: T") @@ -628,6 +638,7 @@ REGISTER_OP("BatchSvd") .Attr("compute_uv: bool = True") .Attr("full_matrices: bool = False") .Attr("T: {double, float, complex64, complex128}") - .Deprecated(13, "Use Svd instead."); + .Deprecated(13, "Use Svd instead.") + .SetShapeFn(shape_inference::UnknownShape); } // namespace tensorflow diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 35c31c6cb81..cfd3869d059 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -24034,6 +24034,25 @@ op { } summary: "Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object." } +op { + name: "SerializeTensor" + input_arg { + name: "tensor" + description: "A Tensor of type `T`." + type_attr: "T" + } + output_arg { + name: "serialized" + description: "A serialized TensorProto proto of the input tensor." + type: DT_STRING + } + attr { + name: "T" + type: "type" + description: "The type of the input tensor." + } + summary: "Transforms a Tensor into a serialized TensorProto proto." 
+}
 op {
   name: "SetSize"
   input_arg {
@@ -24535,6 +24554,54 @@ op {
   summary: "Return a slice from \'input\'."
   description: "The output tensor is a tensor with dimensions described by \'size\'\nwhose values are extracted from \'input\' starting at the offsets in\n\'begin\'.\n\n*Requirements*:\n  0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n)"
 }
+op {
+  name: "SloppyInterleaveDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "cycle_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "block_length"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "f"
+    type: "func"
+    description: "A function mapping elements of `input_dataset`, concatenated with\n`other_arguments`, to a Dataset resource that contains elements matching\n`output_types` and `output_shapes`."
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+  description: "The resulting dataset is similar to the `InterleaveDataset`, with the exception\nthat if retrieving the next value from a dataset would cause the requester to\nblock, it will skip that input dataset. This dataset is especially useful\nwhen loading data from variable-latency datastores (e.g. HDFS, GCS), as it\nallows the training step to proceed so long as some data is available.\n\n!! WARNING !! This dataset is not deterministic!"
+  is_stateful: true
+}
 op {
   name: "Softmax"
   input_arg {
@@ -28908,6 +28975,10 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
+        type: DT_UINT8
+        type: DT_INT8
+        type: DT_UINT16
+        type: DT_INT16
         type: DT_INT32
         type: DT_INT64
         type: DT_COMPLEX64
diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h
index 04ff9e12b6f..d5f7350cdd8 100644
--- a/tensorflow/core/platform/default/logging.h
+++ b/tensorflow/core/platform/default/logging.h
@@ -86,7 +86,7 @@ class LogMessageFatal : public LogMessage {
   ((lvl) <= ::tensorflow::internal::LogMessage::MinVLogLevel())
 #endif
 
-#define VLOG(lvl)                          \
+#define VLOG(lvl)                            \
   if (TF_PREDICT_FALSE(VLOG_IS_ON(lvl))) \
   ::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::INFO)
 
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 50dd0cd58b8..c9b362f1823 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -226,14 +226,28 @@ TEST_F(DefaultEnvTest, RecursivelyCreateDirSubdirsExist) {
 
 TEST_F(DefaultEnvTest, LocalFileSystem) {
   // Test filename with file:// syntax.
+  int expected_num_files = 0;
+  std::vector<string> matching_paths;
   for (const int length : {0, 1, 1212, 2553, 4928, 8196, 9000,
                            (1 << 20) - 1, 1 << 20, (1 << 20) + 1}) {
-    string filename = io::JoinPath(BaseDir(), strings::StrCat("file", length));
+    string filename = io::JoinPath(BaseDir(), strings::StrCat("len", length));
     filename = strings::StrCat("file://", filename);
 
     // Write a file with the given length
     const string input = CreateTestFile(env_, filename, length);
+    ++expected_num_files;
+
+    // Ensure that GetMatchingPaths works as intended.
+    TF_EXPECT_OK(env_->GetMatchingPaths(
+        // Try it with the "file://" URI scheme.
+        strings::StrCat("file://", io::JoinPath(BaseDir(), "l*")),
+        &matching_paths));
+    EXPECT_EQ(expected_num_files, matching_paths.size());
+    TF_EXPECT_OK(env_->GetMatchingPaths(
+        // Try it without any URI scheme.
+        io::JoinPath(BaseDir(), "l*"), &matching_paths));
+    EXPECT_EQ(expected_num_files, matching_paths.size());
 
     // Read the file back and check equality
     string output;
diff --git a/tensorflow/core/util/activation_mode.cc b/tensorflow/core/util/activation_mode.cc
index 4bf947a0a9a..efb5ab146aa 100644
--- a/tensorflow/core/util/activation_mode.cc
+++ b/tensorflow/core/util/activation_mode.cc
@@ -22,7 +22,9 @@ namespace tensorflow {
 
 Status GetActivationModeFromString(const string& str_value,
                                    ActivationMode* value) {
-  if (str_value == "Sigmoid") {
+  if (str_value == "None") {
+    *value = NONE;
+  } else if (str_value == "Sigmoid") {
     *value = SIGMOID;
   } else if (str_value == "Relu") {
     *value = RELU;
diff --git a/tensorflow/core/util/activation_mode.h b/tensorflow/core/util/activation_mode.h
index 2a8564847dd..2e03ccd5c85 100644
--- a/tensorflow/core/util/activation_mode.h
+++ b/tensorflow/core/util/activation_mode.h
@@ -28,6 +28,7 @@ namespace tensorflow {
 
 // ActivationMode: the activation function we apply to the input tensor:
 enum ActivationMode {
+  NONE = 0,
   SIGMOID = 1,
   RELU = 2,
   RELU6 = 3,
diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/programmers_guide/datasets.md
index ba26bd5e941..aaebabfddf9 100644
--- a/tensorflow/docs_src/programmers_guide/datasets.md
+++ b/tensorflow/docs_src/programmers_guide/datasets.md
@@ -146,6 +146,9 @@ for i in range(100):
   assert i == value
 ```
 
+Note: Currently, one-shot iterators are the only type that is easily usable
+with an `Estimator`.
+
 An **initializable** iterator requires you to run an explicit
 `iterator.initializer` operation before using it. In exchange for this
 inconvenience, it enables you to *parameterize* the definition of the dataset,
@@ -452,6 +455,9 @@ dataset = dataset.flat_map(
         .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))
 ```
 
+For a full example of parsing a CSV file using datasets, see [`imports85.py`](https://www.tensorflow.org/code/tensorflow/examples/get_started/regression/imports85.py)
+in @{$get_started/linear_regression}.
+
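To make the datasets.md note above concrete: because one-shot iterators need no explicit initializer op, their output tensors can be returned directly from an `Estimator` input function. A minimal sketch of that pattern, assuming the contrib-era `tf.contrib.data` API; the file name and `parse_fn` are illustrative placeholders, not part of this change:

```python
import tensorflow as tf

def train_input_fn():
  # "train.tfrecords" and parse_fn are placeholders for this sketch.
  dataset = tf.contrib.data.TFRecordDataset(["train.tfrecords"])
  dataset = dataset.map(parse_fn)   # parse_fn yields (features, label) pairs
  dataset = dataset.shuffle(10000).batch(32)
  # A one-shot iterator requires no initialization, so the tensors below
  # work inside the session that Estimator.train() manages internally.
  features, labels = dataset.make_one_shot_iterator().get_next()
  return features, labels
```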