diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 29b726c1a30..c0a47cf6b4a 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -375,6 +375,7 @@ config_setting( package_group( name = "internal", packages = [ + "//learning/meta_rank/...", "//tensorflow/...", "//tensorflow_fold/llgtm/...", ], diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc index 09aee39d8cd..4bc209b7ecf 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc @@ -39,21 +39,23 @@ static void AllocateFlags() { flags->tf_xla_min_cluster_size = 2; flags->tf_xla_max_cluster_size = std::numeric_limits::max(); flags->tf_xla_clustering_debug = false; - flag_list = new std::vector({ - Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, - "Control compilation of operators into XLA computations on CPU and " - "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " - "things very likely to be improved; 2 = on for everything. " - "Experimental."), - Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, - "Minimum number of operators in an XLA compilation. Ignored for " - "operators placed on an XLA device or operators explicitly marked " - "for compilation."), - Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, - "Maximum number of operators in an XLA compilation."), - Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, - "Dump graphs during XLA compilation."), - }); + flags->tf_xla_cpu_global_jit = false; + flag_list = new std::vector( + {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, + "Control compilation of operators into XLA computations on CPU and " + "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " + "things very likely to be improved; 2 = on for everything. " + "Experimental."), + Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, + "Minimum number of operators in an XLA compilation. Ignored for " + "operators placed on an XLA device or operators explicitly marked " + "for compilation."), + Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, + "Maximum number of operators in an XLA compilation."), + Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, + "Dump graphs during XLA compilation."), + Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit, + "Enables global JIT compilation for CPU via SessionOptions.")}); xla::legacy_flags::ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h index 24f80507428..e1ccd7ddb87 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h @@ -46,6 +46,8 @@ typedef struct { int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA // compilation. bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. + bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU + // via SessionOptions. } MarkForCompilationPassFlags; // Return a pointer to the MarkForCompilationPassFlags struct; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 78d0aa86a8f..74c9791f5ea 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -290,9 +290,11 @@ Status MarkForCompilationPass::Run( global_jit_level = static_cast(flags->tf_xla_auto_jit); } + bool cpu_global_jit = flags->tf_xla_cpu_global_jit; const FunctionLibraryDefinition* fld = options.flib_def; - auto is_compilable = [global_jit_level, fld](const Node* node, - const DeviceType& device_type) { + + auto is_compilable = [global_jit_level, cpu_global_jit, fld]( + const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { @@ -315,7 +317,11 @@ Status MarkForCompilationPass::Run( if (status.ok()) return compile; // Otherwise use the value of global_jit_level. - return registration->enable_jit_by_default && global_jit_level > 0; + // Ignore enable_jit_by_default if global jit compilation for CPU + // is explicitly requested via tf_xla_cpu_global_jit flag + bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; + return (ignore_registration || registration->enable_jit_by_default) && + global_jit_level > 0; }; return RunImpl(options, is_compilable); } @@ -556,6 +562,7 @@ Status MarkForCompilationPass::RunImpl( if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation || registration->requires_compilation) { string& name = cluster_names[cluster]; + if (name.empty()) { name = strings::StrCat("cluster_", cluster_sequence_num++); } diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 376c8108ed6..5a81438b1c4 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -179,7 +179,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 6ef4860f358..40a484da098 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -731,11 +731,12 @@ string DebugString(const Graph& graph, FunctionalizeCond::ClusterHandle::Vector* clusters) { string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n"; std::map subgraphs; + auto name = [](const Node* n) { + return strings::StrCat(n->type_string(), "_", n->id()); + }; for (Node* n : graph.nodes()) { - if (n->IsOp()) { - strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), - " [label=\"", n->name(), "\"];\n"); - } + strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), " [label=\"", + name(n), "\"];\n"); } for (auto kv : subgraphs) { strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n", @@ -743,16 +744,11 @@ string DebugString(const Graph& graph, kv.first.ToString(), "\";\n", kv.second, "}\n"); } for (Node* n : graph.nodes()) { - if (!n->IsOp()) { - continue; - } for (Node* in : n->in_nodes()) { - if (in->IsOp()) { - strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); - } + strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); } } - return strings::StrCat(ret, "}"); + return strings::StrCat(ret, "} // end"); } string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { @@ -761,16 +757,24 @@ string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { return cluster.representative.ToString(); }; for (auto kv : clustered_graph) { - strings::StrAppend(&ret, kv.first.ToString(), " [label=\"", name(kv.second), - " (", kv.second.switch_nodes.size(), ", ", - kv.second.merge_nodes.size(), ")\"];\n"); + if (!kv.second.switch_nodes.empty() || !kv.second.merge_nodes.empty()) { + strings::StrAppend( + &ret, kv.first.ToString(), " [label=\"", name(kv.second), + kv.second.switch_nodes.empty() + ? "" + : strings::StrCat(" switches=", kv.second.switch_nodes.size()), + kv.second.merge_nodes.empty() + ? "" + : strings::StrCat(" merges=", kv.second.merge_nodes.size()), + "\"];\n"); + } } for (auto kv : clustered_graph) { for (auto in : kv.second.in_nodes) { strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n"); } } - return strings::StrCat(ret, "}"); + return strings::StrCat(ret, "} // end"); } bool IsDeadSwitch(const Node* node) { @@ -790,9 +794,6 @@ bool IsDeadSwitch(const Node* node) { void FunctionalizeCond::CreateClusters() { for (Node* node : graph_->nodes()) { - if (!node->IsOp()) { - continue; - } if (IsSwitch(node)) { switch_nodes_.insert(node); } else if (IsMerge(node)) { @@ -825,6 +826,10 @@ void FunctionalizeCond::CreateClusters() { clusters_.at(node).Merge(&clusters_.at(in)); } } + // Group all source clusters together. + if (node->IsSource() || node->in_edges().empty()) { + clusters_.at(node).Merge(&clusters_.at(ClusterHandle(Graph::kSourceId))); + } } } @@ -876,7 +881,7 @@ void FunctionalizeCond::CreateClusteredGraph() { for (const Node* in : node->in_nodes()) { ClusterHandle other_repr = Representative(in); // Skip source, sink and internal edges. - if (!in->IsOp() || other_repr == repr) { + if (other_repr == repr) { continue; } Cluster& cluster_node_in = clustered_graph_[other_repr]; @@ -887,7 +892,7 @@ void FunctionalizeCond::CreateClusteredGraph() { for (const Node* out : node->out_nodes()) { ClusterHandle other_repr = Representative(out); // Skip source, sink and internal edges. - if (!out->IsOp() || other_repr == repr) { + if (other_repr == repr) { continue; } Cluster& cluster_node_out = clustered_graph_[other_repr]; @@ -897,6 +902,7 @@ void FunctionalizeCond::CreateClusteredGraph() { } return cluster_node; }; + update_cluster_for_node(graph_->source_node()); for (Node* node : switch_nodes_) { update_cluster_for_node(node).switch_nodes.insert(node); } @@ -955,7 +961,7 @@ gtl::optional FunctionalizeCond::GetSwitchCluster( for (Cluster* in : merge_cluster.in_nodes) { Cluster* cluster = in; if (in->switch_nodes.empty()) { - if (in->in_nodes.size() != 1) { + if (in->in_nodes.size() != 1 || in->out_nodes.size() != 1) { return gtl::nullopt; } // There is only a single `in` cluster. @@ -1292,11 +1298,8 @@ std::vector> FunctionalizeCond::SortedMergeNodes() { VLOG(2) << "ProcessClusteredGraph"; std::stack> stack; - for (auto& c : clustered_graph_) { - if (c.second.in_nodes.empty()) { - stack.push({0, &c.second}); - } - } + // Initialize with the source node. + stack.push({0, &clustered_graph_[ClusterHandle(Graph::kSourceId)]}); // Perform a depth-first traversal of the clustered graph computing the // switch-merge depth. diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 9833323d851..8f78b4c8f90 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -40,6 +40,11 @@ class ConstOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape shape(proto_.tensor_shape()); + if (proto_.dtype() == DT_STRING) { + LOG(WARNING) << "Not computing Const of type DT_STRING"; + ctx->SetInvalidOutput(0); + return; + } xla::ComputationBuilder* b = ctx->builder(); // To avoid blowups for large constants filled with the same value, diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index b948dfee6ab..a052bb105e7 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -345,6 +345,16 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { expression->set_constant_value(constant); } +void XlaOpKernelContext::SetInvalidOutput(int index) { + const TensorShape shape; + Tensor* output = nullptr; + OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output)); + XlaExpression* expression = CastExpressionFromUninitializedTensor(output); + xla::ComputationDataHandle handle; + handle.set_handle(0); + expression->set_handle(handle); +} + void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { Tensor* output = nullptr; // The shape of the output tensor is the shape of the resource itself diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 5519e89252c..76bcf594e6a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -142,6 +142,10 @@ class XlaOpKernelContext { // SetConstantOutput where possible. void SetConstantOutput(int index, const Tensor& host_tensor); + // Sets output 'index' to an invalid value. + // Any subsequent attempt to consume this output will cause an error. + void SetInvalidOutput(int index); + // Status handling. void SetStatus(const Status& status) { context_->SetStatus(status); } Status status() { return context_->status(); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 3c5b360c8ef..033034b4210 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -265,6 +265,42 @@ bool BufferAssignment::SharesSliceAtIndex( GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie(); } +bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, + const HloInstruction* hlo_b) const { + using SliceSet = + FlatSet; + // Gets the slices all of instr's subshapes. If any subshape doesn't have an + // assigned slice, returns the empty set. + auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { + SliceSet slices; + Status status = ShapeUtil::ForEachSubshapeWithStatus( + instr->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) { + auto shape_slices = GetAllSlices(instr, index); + if (shape_slices.empty()) { + return InvalidArgument("No slices assigned to part of instr."); + } + slices.insert(shape_slices.begin(), shape_slices.end()); + return Status::OK(); + }); + if (!status.ok()) { + return {}; + } + return slices; + }; + + SliceSet slices_a = collect_slices(hlo_a); + SliceSet slices_b = collect_slices(hlo_b); + // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e. + // didn't return the empty set) for both HLOs, and the two resulting sets of + // slices are disjoint. + return !slices_a.empty() && !slices_b.empty() && + std::none_of(slices_a.begin(), slices_a.end(), + [&](const BufferAllocation::Slice& slice) { + return slices_b.count(slice) > 0; + }); +} + StatusOr BufferAssignment::GetUniqueTopLevelOutputSlice() const { return GetUniqueTopLevelSlice( diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 08a53af8baa..08a40bfeb2a 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -327,6 +327,12 @@ class BufferAssignment { return SharesSliceAtIndex(hlo_a, {}, hlo_b, {}); } + // Returns true if hlo_a and hlo_b both have at least one buffer assigned for + // their top-level and each of their nested shape indices, and if hlo_a's + // buffers are all different from hlo_b's buffers. + bool HaveDisjointSlices(const HloInstruction* hlo_a, + const HloInstruction* hlo_b) const; + // Returns the underlying points-to analysis used for this assignment. const TuplePointsToAnalysis& points_to_analysis() const { return liveness_->points_to_analysis(); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 3b1900428af..e2e9d2a0c04 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -27,14 +27,8 @@ namespace se = ::perftools::gputools; namespace xla { -/* static */ tensorflow::mutex* Compiler::platform_compiler_mutex_; - -/* static */ void Compiler::LazyInitMutex() { - static std::once_flag mutex_init_flag; - std::call_once(mutex_init_flag, []() { - Compiler::platform_compiler_mutex_ = new tensorflow::mutex; - }); -} +/* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( + tensorflow::LINKER_INITIALIZED); /* static */ std::map* @@ -55,8 +49,7 @@ Compiler::GetPlatformCompilers() { /* static */ void Compiler::RegisterCompilerFactory( se::Platform::Id platform_id, std::function()> compiler_factory) { - LazyInitMutex(); - tensorflow::mutex_lock lock(*platform_compiler_mutex_); + tensorflow::mutex_lock lock(platform_compiler_mutex_); auto* factories = GetPlatformCompilerFactories(); CHECK(factories->find(platform_id) == factories->end()) << "Compiler factory already registered for platform"; @@ -65,8 +58,7 @@ Compiler::GetPlatformCompilers() { /* static */ StatusOr Compiler::GetForPlatform( const se::Platform* platform) { - LazyInitMutex(); - tensorflow::mutex_lock lock(*platform_compiler_mutex_); + tensorflow::mutex_lock lock(platform_compiler_mutex_); auto* compilers = GetPlatformCompilers(); // See if we already instantiated a compiler for this platform. diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 4c2d9600d90..5f021900c8b 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -157,8 +157,7 @@ class Compiler { private: // Mutex that guards the platform-compiler map. - static tensorflow::mutex* platform_compiler_mutex_; - static void LazyInitMutex(); + static tensorflow::mutex platform_compiler_mutex_; // Map from platform kind to compiler factory. static std::map* diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index cdfa30dd9a7..6b7b0d25e87 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -94,7 +94,7 @@ StatusOr ComputationPlacer::AssignDevices( se::Platform::Id platform_id, ComputationPlacerCreationFunction creation_function) { tensorflow::mutex_lock lock( - *ComputationPlacer::platform_computation_placer_mutex()); + ComputationPlacer::platform_computation_placer_mutex_); auto* computation_placers = GetPlatformComputationPlacers(); CHECK(computation_placers->find(platform_id) == computation_placers->end()); (*computation_placers)[platform_id].creation_function = creation_function; @@ -103,7 +103,7 @@ StatusOr ComputationPlacer::AssignDevices( /* static */ StatusOr ComputationPlacer::GetForPlatform( const se::Platform* platform) { tensorflow::mutex_lock lock( - *ComputationPlacer::platform_computation_placer_mutex()); + ComputationPlacer::platform_computation_placer_mutex_); auto* computation_placers = GetPlatformComputationPlacers(); auto it = computation_placers->find(platform->id()); @@ -122,11 +122,9 @@ StatusOr ComputationPlacer::AssignDevices( return it->second.placer.get(); } -/* static */ tensorflow::mutex* -ComputationPlacer::platform_computation_placer_mutex() { - static tensorflow::mutex* m = new tensorflow::mutex; - return m; -} +/* static */ tensorflow::mutex + ComputationPlacer::platform_computation_placer_mutex_( + tensorflow::LINKER_INITIALIZED); /* static */ std::map* diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index 7d9abcd100d..737ccabaa7a 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -89,11 +89,8 @@ class ComputationPlacer { const perftools::gputools::Platform* platform); private: - // Routine that returns the mutex that guards the platform-to-computation - // placer map. Done as a routine to ensure correct initialization ordering, - // since RegisterComputationPlacer can be called during program initialization - // time. - static tensorflow::mutex* platform_computation_placer_mutex(); + // The mutex that guards the platform-to-computation placer map. + static tensorflow::mutex platform_computation_placer_mutex_; // State kept for each kind of ComputationPlacer. Registration functions set // up creation_function, and then we use that to lazily create "placer" the diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 4f6e69ebd4e..78216f2ffb9 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -17,6 +17,7 @@ package_group( load(":build_defs.bzl", "runtime_copts") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") # Filegroup used to collect source files for dependency checking. filegroup( @@ -83,6 +84,7 @@ cc_library( ":cpu_options", ":cpu_parallelization_preparation", ":disassembler", + ":dot_op_emitter", ":ir_emission_utils", ":ir_emitter", ":layout_assignment", @@ -156,21 +158,23 @@ cc_library( ":custom_call_target_registry", ":disassembler", ":external_constant_pool", + ":orc_jit_memory_mapper", ":runtime_conv2d", ":runtime_fork_join", ":runtime_matmul", ":runtime_single_threaded_conv2d", ":runtime_single_threaded_matmul", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - "@llvm//:core", "@llvm//:execution_engine", + "@llvm//:core", "@llvm//:mc", # fixdeps: keep "@llvm//:orc_jit", "@llvm//:support", "@llvm//:target", # fixdeps: keep - ], + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ] + ORC_JIT_MEMORY_MAPPER_TARGETS, ) cc_library( @@ -282,7 +286,6 @@ cc_library( deps = [ ":cpu_options", ":cpu_runtime", - ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", @@ -619,6 +622,7 @@ cc_library( srcs = ["layout_assignment.cc"], hdrs = ["layout_assignment.h"], deps = [ + ":dot_op_emitter", ":ir_emission_utils", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:computation_layout", @@ -706,6 +710,7 @@ cc_library( srcs = ["parallel_task_assignment.cc"], hdrs = ["parallel_task_assignment.h"], deps = [ + ":dot_op_emitter", ":ir_emission_utils", ":shape_partition", "//tensorflow/compiler/xla/service:hlo", @@ -735,6 +740,16 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "orc_jit_memory_mapper", + srcs = ["orc_jit_memory_mapper.cc"], + hdrs = ["orc_jit_memory_mapper.h"], + deps = [ + "//tensorflow/core:lib", + "@llvm//:execution_engine", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index d2202252d95..def801d9d69 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 2a447a54b01..4c40dae5122 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,7 +23,6 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -950,5 +949,119 @@ llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( return index; } +// Return whether the given shape is a matrix with no padding. +static bool IsRank2WithNoPadding(const Shape& shape) { + return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); +} + +// In a gemm operation where output = lhs * rhs, check whether the given shapes +// are valid for the operation. +static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, + const Shape& output_shape) { + // The inputs and the output must + // 1) be matrices with no padding, and + // 2) have an allowed element type. + return output_shape.element_type() == F32 && + IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && + IsRank2WithNoPadding(output_shape); +} + +bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { + // For certain types of Dot, we can call Eigen + if (hlo.opcode() == HloOpcode::kDot) { + const Shape& lhs_shape = hlo.operand(0)->shape(); + const Shape& rhs_shape = hlo.operand(1)->shape(); + + if (ShapeUtil::HasZeroElements(lhs_shape) || + ShapeUtil::HasZeroElements(rhs_shape)) { + return false; + } + + if (ProfitableToImplementDotInUntiledLlvmIr(hlo) == + DotInLlvmIrProfitable::kYes || + ProfitableToImplementDotInTiledLlvmIr(hlo)) { + return false; + } + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. + if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { + // The size of the reduction dimension should match. The shape inference + // guarantees this invariant, so the check here is for programming + // errors. + CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); + return true; + } + } + + if (hlo.opcode() == HloOpcode::kFusion && + hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && + hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { + auto* dot = hlo.fused_expression_root(); + const Shape& lhs_shape = dot->operand(0)->shape(); + const Shape& rhs_shape = dot->operand(1)->shape(); + if (ShapeUtil::HasZeroElements(lhs_shape) || + ShapeUtil::HasZeroElements(rhs_shape)) { + return false; + } + return true; + } + + return false; +} + +DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( + const HloInstruction& dot) { + if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) { + const Shape& result_shape = dot.shape(); + // kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1 + // cache line size, so that we can have the reduction dimension of both the + // LHS and RHS matrices and still have some space "left over". This needs + // to be tuned further. + const int64 kReductionDimensionThresholdBytes = 8 * 1024; + const bool single_threaded_eigen = + !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen(); + + // This is the point at which it is better to call into Eigen and shard the + // dot across multiple worker threads. This is a rough estimate by running + // a matmult benchmark on my local machine, and it can be tuned further. + const int64 kMaxSingleThreadedFlops = 16 * 1024; + + const int64 M = result_shape.dimensions(0); + const int64 N = result_shape.dimensions(1); + const int64 K = dot.operand(1)->shape().dimensions(0); + const int64 primitive_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type()); + if (M == 1 && + K * primitive_type_size <= kReductionDimensionThresholdBytes && + (single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) { + // Heuristics: + // + // - Look for a configuration where we will likely be able to keep LHS in + // L1 and do a cache-optimal traversal of RHS. + // + // - Bail out on matrices that are large enough that Eigen can profitably + // shard the computation across multiple cores. This only applies when + // multi-threading is enabled. + return LayoutUtil::IsMonotonicWithDim0Major( + dot.operand(1)->shape().layout()) + ? DotInLlvmIrProfitable::kWithColumnMajorRhs + : DotInLlvmIrProfitable::kYes; + } + } + return DotInLlvmIrProfitable::kNo; +} + +bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { + // Any Matrix-Vector product of floating point or integral type, or + // a transpose-dot fusion of the same can be lowered to a tiled LLVM + // IR implementation. + const Shape& shape = dot.shape(); + return shape.dimensions_size() == 2 && + (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) && + (primitive_util::IsFloatingPointType(shape.element_type()) || + primitive_util::IsIntegralType(shape.element_type())); +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 470bf6ffb4c..c9168ccc0f6 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -30,6 +30,26 @@ limitations under the License. namespace xla { namespace cpu { +bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); + +enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs }; + +// Returns a value to indicate if (and under what conditions) will lowering +// |dot| as a untiled LLVM IR dot operation be profitable over calling into +// Eigen or emitting a tiled LLVM IR implementation. Possible return values +// are: +// +// * DotInLlvmIrProfitable::kYes - always profitable. +// * DotInLlvmIrProfitable::kNo - never profitable. +// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make +// the Rhs layout column major. +DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( + const HloInstruction& dot); + +// Returns true to indicate that we can generate a tiled LLVM IR implementation +// for |dot|. +bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot); + // Helper class for emitting LLVM IR to perform the dot operation. class DotOpEmitter { public: diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 7149a193107..cb5cb8a6dd6 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -74,122 +74,5 @@ bool PotentiallyImplementedAsEigenConvolution( kernel_shape.dimensions_size() - 1; } -namespace { - -// Return whether the given shape is a matrix with no padding. -bool IsRank2WithNoPadding(const Shape& shape) { - return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); -} - -// In a gemm operation where output = lhs * rhs, check whether the given shapes -// are valid for the operation. -bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { - // The inputs and the output must - // 1) be matrices with no padding, and - // 2) have an allowed element type. - return output_shape.element_type() == F32 && - IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape); -} -} // namespace - -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { - // For certain types of Dot, we can call Eigen - if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - - if (ProfitableToImplementDotInUntiledLlvmIr(hlo) == - DotInLlvmIrProfitable::kYes || - ProfitableToImplementDotInTiledLlvmIr(hlo)) { - return false; - } - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); - return true; - } - } - - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - auto* dot = hlo.fused_expression_root(); - const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - return true; - } - - return false; -} - -DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( - const HloInstruction& dot) { - if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) { - const Shape& result_shape = dot.shape(); - // kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1 - // cache line size, so that we can have the reduction dimension of both the - // LHS and RHS matrices and still have some space "left over". This needs - // to be tuned further. - const int64 kReductionDimensionThresholdBytes = 8 * 1024; - const bool single_threaded_eigen = - !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen(); - - // This is the point at which it is better to call into Eigen and shard the - // dot across multiple worker threads. This is a rough estimate by running - // a matmult benchmark on my local machine, and it can be tuned further. - const int64 kMaxSingleThreadedFlops = 16 * 1024; - - const int64 M = result_shape.dimensions(0); - const int64 N = result_shape.dimensions(1); - const int64 K = dot.operand(1)->shape().dimensions(0); - const int64 primitive_type_size = - ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type()); - if (M == 1 && - K * primitive_type_size <= kReductionDimensionThresholdBytes && - (single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) { - // Heuristics: - // - // - Look for a configuration where we will likely be able to keep LHS in - // L1 and do a cache-optimal traversal of RHS. - // - // - Bail out on matrices that are large enough that Eigen can profitably - // shard the computation across multiple cores. This only applies when - // multi-threading is enabled. - return LayoutUtil::IsMonotonicWithDim0Major( - dot.operand(1)->shape().layout()) - ? DotInLlvmIrProfitable::kWithColumnMajorRhs - : DotInLlvmIrProfitable::kYes; - } - } - return DotInLlvmIrProfitable::kNo; -} - -bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { - // Any Matrix-Vector product of floating point or integral type, or - // a transpose-dot fusion of the same can be lowered to a tiled LLVM - // IR implementation. - const Shape& shape = dot.shape(); - return shape.dimensions_size() == 2 && - (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) && - (primitive_util::IsFloatingPointType(shape.element_type()) || - primitive_util::IsIntegralType(shape.element_type())); -} - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index cbe07a7c2b9..ac361ddfb4c 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -23,27 +23,6 @@ namespace cpu { bool PotentiallyImplementedAsEigenConvolution( const HloInstruction& convolution); - -bool PotentiallyImplementedAsEigenDot(const HloInstruction& dot); - -enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs }; - -// Returns a value to indicate if (and under what conditions) will lowering -// |dot| as a untiled LLVM IR dot operation be profitable over calling into -// Eigen or emitting a tiled LLVM IR implementation. Possible return values -// are: -// -// * DotInLlvmIrProfitable::kYes - always profitable. -// * DotInLlvmIrProfitable::kNo - never profitable. -// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make -// the Rhs layout column major. -DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( - const HloInstruction& dot); - -// Returns true to indicate that we can generate a tiled LLVM IR implementation -// for |dot|. -bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot); - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc index b75ca34e0a8..3f2d101959d 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc new file mode 100644 index 00000000000..e624e5cc7eb --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace xla { +namespace cpu { +namespace orc_jit_memory_mapper { + +static tensorflow::mutex mapper_instance_mutex(tensorflow::LINKER_INITIALIZED); +static llvm::SectionMemoryManager::MemoryMapper* mapper_instance + GUARDED_BY(mapper_instance_mutex) = nullptr; + +llvm::SectionMemoryManager::MemoryMapper* GetInstance() { + tensorflow::mutex_lock lock(mapper_instance_mutex); + return mapper_instance; +} + +Registrar::Registrar( + std::unique_ptr mapper) { + tensorflow::mutex_lock lock(mapper_instance_mutex); + mapper_instance = mapper.release(); +} +} // namespace orc_jit_memory_mapper +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h new file mode 100644 index 00000000000..2d29550fd5b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ + +#include + +#include "llvm/ExecutionEngine/SectionMemoryManager.h" + +namespace xla { +namespace cpu { + +namespace orc_jit_memory_mapper { +// Returns the registered memory mapper if there is one. Returns nullptr if no +// memory mapper is registered. +llvm::SectionMemoryManager::MemoryMapper* GetInstance(); + +class Registrar { + public: + // Registers the `mapper` as a memory mapper. This is a no-op if `mapper` is + // null. Precondition: no other memory mapper has been registered yet. + explicit Registrar( + std::unique_ptr mapper); +}; +} // namespace orc_jit_memory_mapper + +#define XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER(mapper_instance, ctr) \ + static ::xla::cpu::orc_jit_memory_mapper::Registrar \ + XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER_NAME(ctr)(mapper_instance) + +// __COUNTER__ must go through another macro to be properly expanded +#define XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER_NAME(ctr) \ + __orc_jit_memory_mapper_registrar_##ctr + +// Registers the std::unique_ptr +// returned by the `factory` expression. `factory` is allowed to evaluate to +// a null unique_ptr in which case this macro does nothing. +#define XLA_REGISTER_ORC_JIT_MEMORY_MAPPER(factory) \ + XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER(factory, __COUNTER__) +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 4a62a80fac0..4b44ac8941e 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 462614475fc..cda27833079 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" @@ -125,8 +126,10 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), - object_layer_( - [] { return std::make_shared(); }), + object_layer_([] { + return std::make_shared( + orc_jit_memory_mapper::GetInstance()); + }), compile_layer_( object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 6a0eacc66a5..23fb308ec6b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" #include +#include #include #include @@ -258,7 +259,9 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, return InternalError("couldn't get temp CUBIN file name"); } auto cubin_cleaner = tensorflow::gtl::MakeCleanup([&cubin_path] { - TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(cubin_path)); + // CUBIN file may never be created, so the failure to delete it should not + // produce TF error. + tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError(); }); tensorflow::SubProcess ptxas_info_dumper; std::vector ptxas_args = {ptxas_path, ptx_path, "-o", cubin_path, @@ -500,10 +503,24 @@ std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, VLOG(2) << "Compiled PTX size:" << ptx.size() << " CUBIN size: " << cache_value->cubin_data.size(); } else { - LOG(WARNING) - << "Failed to compile ptx to cubin. Will attempt to let " - "GPU driver compile the ptx. " - << maybe_cubin.status(); + bool log_warning = true; + if (maybe_cubin.status().code() == + tensorflow::error::Code::NOT_FOUND) { + // Missing ptxas is expected in some environments where CUDA SDK + // binaries are not available. We don't want to spam logs with + // identical warnings in this case. + + // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N + // for more general usage. + static std::atomic warning_done(false); + log_warning = !warning_done.exchange(true); + } + if (log_warning) { + LOG(WARNING) + << "Failed to compile ptx to cubin. Will attempt to let " + "GPU driver compile the ptx. " + << maybe_cubin.status(); + } } } cache_value->compilation_done = true; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 163a161353f..c2115c49993 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -166,11 +166,46 @@ void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo, *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value; } +// Determines whether hlo's buffers are never modified within the execution of +// consumer. +static bool BuffersInvariantWithinConsumer( + const HloInstruction& hlo, const HloInstruction& consumer, + const BufferAssignment* buffer_assignment) { + // Check if consumer is inside a fusion node -- if so, "dereference" it until + // we get to a non-fusion node. + const HloInstruction* c = &consumer; + while (c->IsFused()) { + c = c->parent()->FusionInstruction(); + } + + // If, after dereferencing c, we end up with a node that's not inside our + // module's top-level computation (say our node is inside a while loop), we + // give up on marking array as invariant, because this HLO may be run multiple + // times (e.g. multiple while loop iterations, or multiple invocations of a + // reducer's computation). TODO(jlebar): We could relax this constraint if we + // emitted an llvm.invariant.group.barrier at the end of the computation. + return c->parent() == c->GetModule()->entry_computation() && + buffer_assignment->HaveDisjointSlices(&hlo, &consumer); +} + llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, + const HloInstruction& consumer, const ShapeIndex& shape_index) { llvm_ir::IrArray ir_array(GetBasePointer(hlo, shape_index), ShapeUtil::GetSubshape(hlo.shape(), shape_index)); alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); + + // The GPU backend emits one kernel per top-level HLO, and LLVM views + // execution of one kernel as the "whole program" executed on the GPU. + // Therefore if hlo's output buffer is not modified within consumer, and if + // consumer runs hlo only once (so that it doesn't create two different + // outputs), then we can mark ir_array as invariant over the whole program. + if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) { + VLOG(2) << "Marking " << hlo.name() << " as invariant within " + << consumer.name(); + ir_array.MarkInvariantOverWholeProgram(&module_->getContext()); + } + return ir_array; } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index a3120f15bcb..62ae1769a1f 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -76,8 +76,15 @@ class HloToIrBindings { return it->second.element(shape_index); } - // Return the underlying IrArray of the output of the given instruction. + // Returns the IrArray which contains the output of hlo. + // + // consumer is the HLO in which this IrArray is used -- we use this to (try + // to) add metadata indicating that the array is invariant within consumer. + // + // To get the buffer into which hlo should write its own output, call + // GetIrArray(hlo, hlo). llvm_ir::IrArray GetIrArray(const HloInstruction& hlo, + const HloInstruction& consumer, const ShapeIndex& shape_index = {}); private: diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index af2a92e11e5..6e2bd4e11d3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -68,7 +68,8 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArray(*operand, *hlo) + .EmitReadArrayElement(index, &ir_builder_); }; } return EmitTargetElementLoop( @@ -145,7 +146,8 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { for (const HloInstruction* operand : tuple->operands()) { base_ptrs.push_back(GetBasePointer(*operand)); } - llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_, module_); + llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &ir_builder_, + module_); return Status::OK(); } @@ -334,7 +336,8 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { TF_RET_CHECK(pred->shape().element_type() == PRED); if (ShapeUtil::IsTuple(select->shape())) { - llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred), + llvm_ir::EmitTupleSelect(GetIrArray(*select, *select), + GetIrArray(*pred, *select), GetBasePointer(*on_true), GetBasePointer(*on_false), &ir_builder_, module_); return Status::OK(); @@ -349,9 +352,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { Status IrEmitter::HandleDot(HloInstruction* dot) { auto lhs_instruction = dot->operand(0); auto rhs_instruction = dot->operand(1); - const llvm_ir::IrArray& target_array = GetIrArray(*dot); - const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction); - const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction); + const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot); + const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot); + const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot); const Shape& lhs_shape = lhs_instruction->shape(); const Shape& rhs_shape = rhs_instruction->shape(); @@ -571,7 +574,8 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // Apply the reduction function to the loaded value. llvm::Value* input_address = - GetIrArray(*arg).EmitArrayElementAddress(input_index, &ir_builder_); + GetIrArray(*arg, *reduce) + .EmitArrayElementAddress(input_index, &ir_builder_); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *function, {accumulator_addr, input_address}, accumulator_addr)); @@ -587,7 +591,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand)); + parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &ir_builder_, GetNestedComputer()); @@ -622,7 +626,8 @@ Status IrEmitter::HandleRng(HloInstruction* random) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : random->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArray(*operand, *random) + .EmitReadArrayElement(index, &ir_builder_); }; } // Emits a single-threaded loop because the loop body generated by the element @@ -631,7 +636,7 @@ Status IrEmitter::HandleRng(HloInstruction* random) { GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_, GetNestedComputer()) .MakeElementGenerator(random, operand_to_generator), - GetIrArray(*random), &ir_builder_) + GetIrArray(*random, *random), &ir_builder_) .EmitLoop(IrName(random)); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 61fdeaa0ee7..9c01f5b7c72 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -105,10 +105,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { explicit IrEmitter(const HloModuleConfig& hlo_module_config, IrEmitterContext* ir_emitter_context, bool is_nested); - // A convenient helper for calling HloToIrBindings::GetIrArray. + // Helper for calling HloToIrBindings::GetIrArray. + // + // Gets the IrArray which contains inst. This array has metadata that makes + // it valid only within the IR that implements consumer. If you are + // implementing an HLO and want to get its own output buffer, call + // GetIrArray(hlo, hlo). llvm_ir::IrArray GetIrArray(const HloInstruction& inst, + const HloInstruction& consumer, const ShapeIndex& shape_index = {}) { - return bindings_.GetIrArray(inst, shape_index); + return bindings_.GetIrArray(inst, consumer, shape_index); } // A convenient helper for calling HloToIrBindings::GetBasePointer. llvm::Value* GetBasePointer(const HloInstruction& inst) const { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 5da1a130d56..5225ff36ff3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -115,7 +115,8 @@ Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { Status IrEmitterNested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { - return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo), &ir_builder_) + return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), + &ir_builder_) .EmitLoop(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index db78f4b84dc..1b863c9e3c5 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -282,7 +282,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { MakeUnique(std::move(thunks), fusion)); std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand)); + parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter( hlo_module_config_, ir_emitter_context_->llvm_module(), @@ -344,7 +344,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); std::vector operand_arrays; for (HloInstruction* operand : fusion->operands()) { - operand_arrays.push_back(GetIrArray(*operand)); + operand_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), @@ -355,7 +355,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // Array to write into. Because this is an in-place operation, this is the // same as operand 0's array. - llvm_ir::IrArray output_array = GetIrArray(*fusion); + llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( update_shape, ir_emitter_context_->device_description()); @@ -693,9 +693,10 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { constexpr int64 tile_size = 32; constexpr int64 num_rows = 8; int64 num_tiles = EmitTranspose021Tiled( - GetIrArray(*(copy->operand(0))) + GetIrArray(*copy->operand(0), *copy) .CastToShape(reduced_input_shape, &ir_builder_), - GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_), + GetIrArray(*copy, *copy) + .CastToShape(reduced_output_shape, &ir_builder_), tile_size, num_rows, &ir_builder_); UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size), LastThunk(), ir_emitter_context_->llvm_module()); @@ -850,9 +851,11 @@ Status IrEmitterUnnested::EmitColumnReduction( &ir_builder_); const HloInstruction* output = reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress( - llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), &ir_builder_, - "output_element_address"); + llvm::Value* output_address = + GetIrArray(*output, *output) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), + &ir_builder_, "output_element_address"); return EmitAtomicOperationForNestedComputation( *reducer, output_address, partial_reduction_result_address); }; @@ -1116,9 +1119,11 @@ Status IrEmitterUnnested::EmitRowReduction( "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress( - llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), &ir_builder_, - "output_element_address"); + llvm::Value* output_address = + GetIrArray(*output, *output) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), + &ir_builder_, "output_element_address"); return EmitAtomicOperationForNestedComputation( *reducer, output_address, partial_reduction_result_address); }; @@ -1258,11 +1263,12 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { MakeUnique(std::move(thunks), reduce)); return EmitReductionToVector( reduce, input->shape(), - [this, input](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*input).EmitReadArrayElement(index, &ir_builder_); + [&](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*input, *reduce) + .EmitReadArrayElement(index, &ir_builder_); }, - [this, init_value](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value) + [&](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*init_value, *reduce) .EmitReadArrayElement(index, &ir_builder_); }, dimensions_to_reduce, reducer); @@ -1426,7 +1432,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); } }; - llvm_ir::IrArray operand_array(GetIrArray(*operand)); + llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); ir_builder_.CreateStore(operand_data, selected_value_address); @@ -1479,9 +1485,10 @@ Status IrEmitterUnnested::HandleSelectAndScatter( ir_builder_.CreateLoad(selected_index_address_slot)); } llvm::Value* source_value_address = - GetIrArray(*source).EmitArrayElementAddress(source_index, &ir_builder_); + GetIrArray(*source, *select_and_scatter) + .EmitArrayElementAddress(source_index, &ir_builder_); llvm::Value* output_value_address = - GetIrArray(*select_and_scatter) + GetIrArray(*select_and_scatter, *select_and_scatter) .EmitArrayElementAddress(selected_index, &ir_builder_); return EmitAtomicOperationForNestedComputation( *select_and_scatter->scatter(), output_value_address, @@ -1758,7 +1765,7 @@ Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, return EmitTargetElementLoopInThunk( *hlo, [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value) + return GetIrArray(*init_value, *hlo) .EmitReadArrayElement(index, &ir_builder_); }, thunk); @@ -1859,7 +1866,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); if (!hlo.IsMultiOutputFusion()) { - return ParallelLoopEmitter(element_generator, GetIrArray(hlo), + return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), launch_dimensions, &ir_builder_) .EmitLoop(IrName(&hlo)); } @@ -1867,7 +1874,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( // For multiple outputs fusion, we need to emit each operand and the root. std::vector output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { - output_arrays.push_back(GetIrArray(hlo, {i})); + output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &ir_builder_) @@ -1878,7 +1885,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_, + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, module_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index ff80f18bb56..3f34b9ceb34 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -75,11 +75,43 @@ HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction, std::forward_as_tuple(value_id, instruction, index, is_phi)); CHECK(emplaced.second); + VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString(); + return &emplaced.first->second; } -void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) { - values_.erase(value_id); +void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { + HloValue& value = values_.at(value_id); + VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")"; + + value_ids_to_delete_.push_back(value_id); +} + +void HloDataflowAnalysis::DeleteMarkedValues() { +#ifndef NDEBUG + // Verify that no marked-for-deletion values are in any of the value sets. + tensorflow::gtl::FlatSet id_set(value_ids_to_delete_.begin(), + value_ids_to_delete_.end()); + for (const auto& pair : value_sets_) { + const HloInstruction* instruction = pair.first; + const InstructionValueSet& instruction_value_set = pair.second; + for (const auto& index_value_set : instruction_value_set) { + const HloValueSet& value_set = index_value_set.second; + for (const HloValue* value : value_set.values()) { + DCHECK(!ContainsKey(id_set, value->id())) + << "Value " << value->ToShortString() + << " marked for deletion, but still exists in value set for " + "instruction " + << instruction->name(); + } + } + } +#endif + + for (HloValue::Id value_id : value_ids_to_delete_) { + values_.erase(value_id); + } + value_ids_to_delete_.clear(); } string HloDataflowAnalysis::ToString() const { @@ -121,6 +153,7 @@ bool HloDataflowAnalysis::Phi( HloInstruction* instruction, tensorflow::gtl::ArraySlice inputs) { CHECK(ssa_form_); + VLOG(4) << "Phi(" << instruction->name() << ")"; for (const InstructionValueSet* input : inputs) { DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape())); @@ -183,7 +216,7 @@ bool HloDataflowAnalysis::Phi( } else if (current_value != &new_value) { if (current_value_defined_here) { // Remove the existing phi. - DeleteHloValue(current_value->id()); + MarkValueForDeletion(current_value->id()); } value_set.Clear(); value_set.AddValue(&new_value); @@ -193,7 +226,8 @@ bool HloDataflowAnalysis::Phi( // Multiple distinct values reach this point. A phi value is // necessary. CHECK_GT(input_value_ids.size(), 1); - if (current_value == nullptr || !current_value->is_phi()) { + if (current_value == nullptr || + !(current_value->is_phi() && current_value_defined_here)) { value_set.Clear(); value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true)); changed = true; @@ -485,11 +519,13 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( } } -void HloDataflowAnalysis::UpdateInstructionsAndPropagate( - tensorflow::gtl::ArraySlice instructions) { +void HloDataflowAnalysis::Propagate() { std::queue worklist; - for (HloInstruction* instruction : instructions) { - worklist.push(instruction); + + for (HloComputation* computation : module_->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + worklist.push(instruction); + } } while (!worklist.empty()) { @@ -662,20 +698,17 @@ StatusOr> HloDataflowAnalysis::Run( new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); + dataflow_analysis->Propagate(); - // Construct list of all instructions to initialize the worklist to propagate - // the data flow. For efficiency sort the instruction in post order so - // producers appear before consumers. - std::vector all_instructions; - for (const HloComputation* computation : module->MakeComputationPostOrder()) { - for (HloInstruction* instruction : - computation->MakeInstructionPostOrder()) { - all_instructions.push_back(instruction); - } - } - dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions); + // Delete all values marked for deletion. + dataflow_analysis->DeleteMarkedValues(); - // Add in positions to all values. + // Gather and set all non-definition positions of all values. Value deletion + // is rare, so just use a vector indexed by Value::Id rather than a map from + // Value::Id to positions. There should be very few holes in the vector, and + // lookup is faster. + std::vector> value_positions( + dataflow_analysis->next_value_id_); for (const HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { for (const auto& pair : @@ -684,13 +717,18 @@ StatusOr> HloDataflowAnalysis::Run( const HloValueSet& value_set = pair.second; for (const HloValue* value : value_set.values()) { if (value->defining_instruction() != instruction) { - dataflow_analysis->GetValue(value->id()) - .AddPosition(instruction, index); + value_positions[value->id()].push_back( + HloPosition{instruction, index}); } } } } } + for (auto& pair : dataflow_analysis->values_) { + HloValue::Id value_id = pair.first; + HloValue& value = pair.second; + value.SetPositionsAndComputeUses(value_positions[value_id]); + } // Construct vector of values. dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size()); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 63467f32060..dfd81ae9510 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -126,13 +126,16 @@ class HloDataflowAnalysis { HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); - // Delete the HloValue with the given ID. - void DeleteHloValue(HloValue::Id value_id); + // Mark the HloValue with the given ID for deletion. + void MarkValueForDeletion(HloValue::Id value_id); + + // Delete all HloValues marked for deletion. Should be called after + // propagation is complete. + void DeleteMarkedValues(); // Constructs and initializes the InstructionValueSets of all instructions to // contain exactly the HloValues defined by each instruction. These values can - // then propagated throughout the HLO graph by calling - // UpdateInstructionsAndPropagate. + // then propagated throughout the HLO graph by calling Propagate. Status InitializeInstructionValueSets(); // Updates the value set of the given instruction based on the values flowing @@ -152,10 +155,8 @@ class HloDataflowAnalysis { bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); - // Update the value sets of the given instructions and propagate the - // changes to fixed point. - void UpdateInstructionsAndPropagate( - tensorflow::gtl::ArraySlice instructions); + // Propagate the dataflow through the module. + void Propagate(); // Return the result of the SSA Phi function applied to the given inputs at // the given instruction. If skip_top_level is true, then the top level of the @@ -191,6 +192,11 @@ class HloDataflowAnalysis { // A map from instruction to InstructionValueSet. std::unordered_map value_sets_; + // Values marked for deletion during construction. We don't delete them + // immediately because references to them may remain in ValueSets temporarily + // during propagation. After construction, these values are deleted. + std::vector value_ids_to_delete_; + // A vector containing all HloValues sorted by HloValue::Id. std::vector values_vector_; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 66a538fc519..f08f0b1d683 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -211,10 +211,10 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { HloPosition{nested_tuple, {0, 0}}, HloPosition{nested_tuple, {1, 0}}, HloPosition{nested_tuple, {2}}, HloPosition{gte_tuple, {0}}, HloPosition{gte_out, {}})); - // Constant values should have no uses though one is live out. The positions - // where they appear as operands are on instructions which do not use the - // values (eg, Tuple). - EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); + // Constant values should have only a single use, which is the root of the + // computation. + EXPECT_THAT(analysis.GetValueDefinedAt(constant1, /*index=*/{}).uses(), + UnorderedElementsAre(HloUse{gte_out, 0, {0}})); EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); // The top-level tuple values are used in GTE instructions. @@ -274,12 +274,11 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}})); + UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}})); + UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{add, 1, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { @@ -323,18 +322,17 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}})); + UnorderedElementsAre(HloUse{call1, 0, {}}, HloUse{call2, 0, {}}, + HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}})); + UnorderedElementsAre(HloUse{call1, 1, {}}, HloUse{call2, 1, {}}, + HloUse{add, 1, {}})); // The Add from the subcomputation is used as both operands of the Subtract. EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(), UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}})); EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); - EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { @@ -408,7 +406,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto outer_param1 = outer_builder.AddInstruction( HloInstruction::CreateParameter(1, scalar_shape_, "param1")); // Swizzle parameters. - outer_builder.AddInstruction(HloInstruction::CreateCall( + auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {outer_param1, outer_param0}, inner_computation)); HloComputation* outer_computation = module_->AddEmbeddedComputation(outer_builder.Build()); @@ -418,7 +416,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(2.0))); - builder.AddInstruction(HloInstruction::CreateCall( + auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, outer_computation)); module_->AddEntryComputation(builder.Build()); @@ -431,10 +429,14 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { // Verify that the uses of the constants are properly swizzled by parameter // permutation in nested_call. - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 1, {}})); - EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 0, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}}, + HloUse{add, 1, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}}, + HloUse{add, 0, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); } @@ -469,7 +471,7 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); - body_builder.AddInstruction( + auto body_root = body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); @@ -496,8 +498,6 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_TRUE( - analysis.GetValueDefinedAt(cond_constant).live_out_of_computation()); EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module()); if (ssa_form) { @@ -517,14 +517,14 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_THAT( analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}})); + UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{body_root, 0, {}}, + HloUse{xla_while, 0, {0}})); // Constant1 passes through the body and out of the module. EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) .live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); } else { // While instruction and subcomputation parameters should not define values @@ -538,7 +538,6 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } } @@ -915,9 +914,11 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { HloUse{select12, 1, {}})); // The two constant values just pass through the Selects and are not - // used. They are live out however. - EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); - EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); + // used except at the root. They are live out however. + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{select1234, 1, {0}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{select1234, 1, {0}})); EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); } @@ -1318,7 +1319,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { auto entry = module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); - const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + RunAnalysis(ssa_form); SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param, xla_while}}); @@ -1329,12 +1330,6 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { SequentialHloOrdering ordering(module_.get(), sequence); - // 'add' is the body root even though later instructions follow in the order - // like 'dead_negate'. Only 'add' should be live out of the computation. - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); - EXPECT_FALSE( - analysis.GetValueDefinedAt(dead_negate).live_out_of_computation()); - // 'add' is live out of the body and will interfere with an later instructions // such as 'dead_constant' and 'dead_negate'. EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant)); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index ecce2bd4e51..755374b91d0 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" namespace xla { -HloToProfileIndex::HloToProfileIndex(const HloModule& module) { +HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { size_t current_profile_index = 0; for (xla::HloComputation* computation : module.MakeComputationPostOrder()) { InsertOrDie(&computation_to_profile_idx_, computation, @@ -41,24 +41,24 @@ HloToProfileIndex::HloToProfileIndex(const HloModule& module) { } static HloProfilePrinter CreateOwnedHloProfilePrinter( - const HloToProfileIndex& hlo_to_profile_index, + const HloProfileIndexMap& hlo_profile_index_map, const HloCostAnalysis& cost_analysis) { using HloComputationInfo = HloProfilePrinter::HloComputationInfo; using HloInstructionInfo = HloProfilePrinter::HloInstructionInfo; HloComputationInfo* computation_infos = - new HloComputationInfo[hlo_to_profile_index.computation_count()]; + new HloComputationInfo[hlo_profile_index_map.computation_count()]; // There are two "indices" in play here. The first one is the index of the // HloComputationInfo or HloInstructionInfo in the array that contains said // HloComputationInfo or HloInstructionInfo. The second index is the index of // the HloComputationInfo or HloInstructionInfo in the profile counters array, - // as decided by hlo_to_profile_index. The latter index is always referred to - // as "profile_index". + // as decided by hlo_profile_index_map. The latter index is always referred + // to as "profile_index". size_t computation_index_in_static_data = 0; - size_t max_profile_index = hlo_to_profile_index.total_count(); - for (const auto& pair : hlo_to_profile_index.computation_to_profile_idx()) { + size_t max_profile_index = hlo_profile_index_map.total_count(); + for (const auto& pair : hlo_profile_index_map.computation_to_profile_idx()) { CHECK_LT(pair.second, max_profile_index); const HloComputation* computation = pair.first; size_t current_computation_index = computation_index_in_static_data++; @@ -85,7 +85,7 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter( instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo); instruction_info->seconds = cost_analysis.seconds(*hlo); instruction_info->profile_index = - hlo_to_profile_index.GetProfileIndexFor(*hlo); + hlo_profile_index_map.GetProfileIndexFor(*hlo); CHECK_LT(instruction_info->profile_index, max_profile_index); } } @@ -109,26 +109,26 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter( }; return HloProfilePrinter(computation_infos, - hlo_to_profile_index.computation_count(), deleter); + hlo_profile_index_map.computation_count(), deleter); } HloExecutionProfile::HloExecutionProfile(const HloModule& module, const HloCostAnalysis& cost_analysis) - : hlo_to_profile_index_(module), + : hlo_profile_index_map_(module), hlo_profile_printer_( - CreateOwnedHloProfilePrinter(hlo_to_profile_index_, cost_analysis)), + CreateOwnedHloProfilePrinter(hlo_profile_index_map_, cost_analysis)), profile_counters_( - /*count*/ hlo_to_profile_index_.total_count(), + /*count*/ hlo_profile_index_map_.total_count(), /*value*/ 0) {} void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken) { - profile_counters_[hlo_to_profile_index_.GetProfileIndexFor(*hlo)] = + profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(*hlo)] = cycles_taken; } uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const { - return profile_counters_[hlo_to_profile_index_.GetProfileIndexFor(hlo)]; + return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(hlo)]; } string HloExecutionProfile::ToString( diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index f945b9d84c6..84702680c0c 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -29,18 +29,18 @@ namespace xla { class HloInstruction; -// Maps all HloInstructions and HloComputions in an HloModule to integers. -// These integers form the contiguous range [0, GetTotalCount()). -class HloToProfileIndex { +// Maps all HloInstructions and HloComputations in an HloModule to integers. +// These integers form the contiguous range [0, total_count()). +class HloProfileIndexMap { public: - // Scans `module` to populate this instance of HloToProfileIndex. - explicit HloToProfileIndex(const HloModule& module); + // Scans `module` to populate this instance of HloProfileIndexMap. + explicit HloProfileIndexMap(const HloModule& module); - HloToProfileIndex(const HloToProfileIndex&) = default; - HloToProfileIndex(HloToProfileIndex&&) = default; + HloProfileIndexMap(const HloProfileIndexMap&) = default; + HloProfileIndexMap(HloProfileIndexMap&&) = default; - HloToProfileIndex& operator=(const HloToProfileIndex&) = default; - HloToProfileIndex& operator=(HloToProfileIndex&&) = default; + HloProfileIndexMap& operator=(const HloProfileIndexMap&) = default; + HloProfileIndexMap& operator=(HloProfileIndexMap&&) = default; size_t GetProfileIndexFor(const HloInstruction& instruction) const { return FindOrDie(instruction_to_profile_idx(), &instruction); @@ -97,14 +97,14 @@ class HloExecutionProfile { // Return the number of cycles this computation took to execute. uint64 total_cycles_executed(const HloComputation& computation) const { - return profile_counters_[hlo_to_profile_index_.GetProfileIndexFor( + return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor( computation)]; } // Record how many cycles a computation took to execute. void set_total_cycles_executed(const HloComputation& computation, uint64 total_cycles_executed) { - profile_counters_[hlo_to_profile_index_.GetProfileIndexFor(computation)] = + profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(computation)] = total_cycles_executed; } @@ -117,9 +117,9 @@ class HloExecutionProfile { string ToString(const DeviceDescription& device_description) const; private: - // hlo_to_profile_index_ maps an Hlo entity (computation or instruction) to an - // index in profile_counters_. - HloToProfileIndex hlo_to_profile_index_; + // hlo_profile_index_map_ maps an Hlo entity (computation or instruction) to + // an index in profile_counters_. + HloProfileIndexMap hlo_profile_index_map_; // Used to print profile_counters_ in a human readable form. HloProfilePrinter hlo_profile_printer_; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 881b7e227c3..d71a4b42c71 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -970,6 +970,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: return kBrown; + case HloOpcode::kConditional: case HloOpcode::kCustomCall: case HloOpcode::kWhile: case HloOpcode::kCall: @@ -1003,7 +1004,7 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { } string extended_opcode = StrCat(HloOpcodeString(instr->opcode()), - instr->opcode() == HloOpcode::kFusion + instr->opcode() != HloOpcode::kFusion ? "" : StrCat(":", xla::ToString(instr->fusion_kind()))); // If the name does not contain the opcode, render both. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index d3096231dca..f7b5b265d92 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -43,6 +43,7 @@ limitations under the License. namespace xla { +using tensorflow::str_util::CEscape; using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; @@ -1209,6 +1210,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); break; + case HloOpcode::kConditional: case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kSend: @@ -1602,6 +1604,7 @@ bool HloInstruction::IdenticalSlowPath( return dimensions() == other.dimensions(); // These opcodes are not yet supported. + case HloOpcode::kConditional: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: @@ -1965,6 +1968,13 @@ std::vector HloInstruction::ExtraAttributesToString() const { }), "}")); } + if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) { + extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")); + } + if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) { + extra.push_back( + StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); + } return extra; } @@ -2347,6 +2357,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSendDone(this); // These opcodes are not handled here. + case HloOpcode::kConditional: case HloOpcode::kTrace: break; } @@ -2920,7 +2931,6 @@ string PaddingConfigToString(const PaddingConfig& padding) { string OpMetadataToString(const OpMetadata& metadata) { std::vector result; - using tensorflow::str_util::CEscape; if (!metadata.op_type().empty()) { result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\"")); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index e0d02e0665c..7b070274416 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -58,6 +58,7 @@ namespace xla { V(kClamp, "clamp") \ V(kComplex, "complex") \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ + V(kConditional, "conditional") \ V(kConstant, "constant") \ V(kConvert, "convert") \ V(kConvolution, "convolution") \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 37009369797..6f6e679a218 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -173,6 +173,19 @@ bool HloOrdering::UseIsBeforeValueDefinition( return true; } } + + // The use at a call occurs before values that are defined in the called + // computation. + if (use.instruction->opcode() == HloOpcode::kCall) { + const HloInstruction* call = use.instruction; + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + call->to_apply())) { + VLOG(4) << " use is call " << use.instruction->name() + << " and def is in called computation"; + return true; + } + } + VLOG(4) << " use is not before value"; return false; } @@ -187,23 +200,6 @@ bool HloOrdering::LiveRangeStrictlyBefore( return false; } - // Live-out values from the module can never have ranges strictly before any - // other value. - if (a.live_out_of_module()) { - VLOG(4) << "a is live out of module"; - return false; - } - - // Live-out values of computations can never have ranges strictly before any - // other value in the computation (including values nested in - // subcomputations). - if (a.live_out_of_computation() && - call_graph_->InstructionIsNestedIn(b.defining_instruction(), - a.defining_instruction()->parent())) { - VLOG(4) << "a is live out of computation containing b"; - return false; - } - // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (!UseIsBeforeValueDefinition(use, b, dataflow)) { diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index e6cf0d37b8a..05b7dce3d1e 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -71,7 +71,7 @@ HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi) : id_(id), is_phi_(is_phi) { // The defining position is always the first element in the positions_ vector. - AddPosition(instruction, index); + positions_.push_back(HloPosition{instruction, index}); } bool HloValue::operator==(const HloValue& other) const { @@ -130,18 +130,14 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, CHECK_LE(operand_number, 2); return operand_number == 0 || index.empty(); - case HloOpcode::kCall: case HloOpcode::kTuple: // These instructions always pass through their operands transparently. return false; + case HloOpcode::kCall: case HloOpcode::kWhile: - // Though the while instructions passes through its operands, we return - // true because in SSA form there may be a Phi at the parameter of the - // while which is considered a use of its incoming value because the Phi - // input values are not passed through into the body computation. Because - // this function is used in both SSA and non-SSA forms of the analysis - // conservatively return true. + // Although call and while instructions pass through their operands, they + // are considered uses. return true; default: @@ -151,103 +147,58 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, } // namespace -void HloValue::AddPosition(HloInstruction* instruction, - const ShapeIndex& index) { - HloPosition new_position{instruction, index}; +void HloValue::SetPositionsAndComputeUses( + tensorflow::gtl::ArraySlice positions) { + CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once."; - // The new position must not already exist in positions_. + // The positions must be unique and should not contain the defining position + // as this is added at construction time. + for (const HloPosition& position_a : positions) { + DCHECK_NE(position_a, defining_position()); + for (const HloPosition& position_b : positions) { + if (&position_a != &position_b) { + DCHECK_NE(position_a, position_b); + } + } + } + + positions_.insert(positions_.end(), positions.begin(), positions.end()); + + // Gather the computation roots at which this value appears. + tensorflow::gtl::FlatSet root_positions; for (const HloPosition& position : positions_) { - DCHECK_NE(position, new_position); - } - - positions_.push_back(std::move(new_position)); - - // Update uses. - for (HloInstruction* user : instruction->users()) { - for (int64 operand_number : user->OperandIndices(instruction)) { - if (MayUseOperandValue(operand_number, index, user)) { - HloUse new_use{user, operand_number, index}; - - // The new use must not already exist in uses_. - for (const HloUse& use : uses_) { - DCHECK_NE(use, new_use); - } - - uses_.push_back(std::move(new_use)); - } + if (position.instruction == + position.instruction->parent()->root_instruction()) { + root_positions.insert(position.instruction); } } - // Update liveout status of this HloValue. - const HloModule& module = *instruction->parent()->parent(); - if (instruction == module.entry_computation()->root_instruction()) { - live_out_of_module_ = true; - } - - if (instruction == instruction->parent()->root_instruction()) { - live_out_of_computation_ = true; - } -} - -void HloValue::RemovePosition(HloInstruction* instruction, - const ShapeIndex& index) { - // The defining position cannot be removed. - CHECK(!(instruction == defining_instruction() && index == defining_index())); - - int64 size_before = positions_.size(); - positions_.erase( - std::remove_if(positions_.begin(), positions_.end(), - [instruction, &index](const HloPosition& position) { - return position.instruction == instruction && - position.index == index; - }), - positions_.end()); - // Only a single position should have been removed. - CHECK_EQ(positions_.size(), size_before - 1); - - // Update uses which referred to this position. - uses_.erase(std::remove_if(uses_.begin(), uses_.end(), - [instruction, &index](const HloUse& use) { - return use.instruction->operand( - use.operand_number) == instruction && - use.operand_index == index; - }), - uses_.end()); - - // Returns whether this value is contained in the given instruction's output. - auto is_contained_in = [this](const HloInstruction* instruction) { - for (const HloPosition& position : positions()) { - if (position.instruction == instruction) { - return true; - } - } - return false; - }; - - const HloModule& module = *instruction->parent()->parent(); - if (instruction == module.entry_computation()->root_instruction()) { - // Value has been removed from a position in the entry root instruction. - live_out_of_module_ = - is_contained_in(module.entry_computation()->root_instruction()); - } - if (instruction == defining_instruction()->parent()->root_instruction()) { - // Value has been removed from the root of the computation the value has - // been defined in. - live_out_of_computation_ = - is_contained_in(defining_instruction()->parent()->root_instruction()); - } -} - -void HloValue::RecomputeUses() { - uses_.clear(); - for (const HloPosition& position : positions()) { + // Build vector of HloUses for the value. + for (const HloPosition& position : positions_) { for (HloInstruction* user : position.instruction->users()) { for (int64 operand_number : user->OperandIndices(position.instruction)) { - if (MayUseOperandValue(operand_number, position.index, user)) { - uses_.push_back(HloUse{user, operand_number, position.index}); + // Root instructions of computations are considered to be uses whether + // or not the root instruction itself actually uses the value. + if (MayUseOperandValue(operand_number, position.index, user) || + ContainsKey(root_positions, user)) { + HloUse new_use{user, operand_number, position.index}; + + // The new use must not already exist in uses_. + for (const HloUse& use : uses_) { + DCHECK_NE(use, new_use); + } + + uses_.push_back(std::move(new_use)); } } } + + // Update liveout status of this HloValue. + const HloModule& module = *position.instruction->parent()->parent(); + if (position.instruction == + module.entry_computation()->root_instruction()) { + live_out_of_module_ = true; + } } } diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index 6872bc76a82..2a711e8b425 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -121,6 +121,12 @@ class HloValue { HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); + // Sets the positions in the module at which the HloValue appears. Updates + // uses. Should be called once and only once. The defining position should not + // be included in 'positions' as this is set at construction time. + void SetPositionsAndComputeUses( + tensorflow::gtl::ArraySlice positions); + // Return a unique identifier for this HloValue. This value is used for stable // sorting and iteration Id id() const { return id_; } @@ -143,28 +149,15 @@ class HloValue { // Return the shape of this HloValue. const Shape& shape() const { return defining_position().shape(); } - // Add or remove a position at which the HloValue appears. The definition - // position can not be removed. The uses of the HloValue are updated. - void AddPosition(HloInstruction* instruction, const ShapeIndex& index); - void RemovePosition(HloInstruction* instruction, const ShapeIndex& index); - - // Remove all positions except the defining position. Updates uses. - void ClearPositions(); - // Return all positions of the HloValue in the module. const std::vector& positions() const { return positions_; } // Return all uses of the HloValue. const std::vector& uses() const { return uses_; } - void RecomputeUses(); - // Get whether this HloValue is live out of the module. bool live_out_of_module() const { return live_out_of_module_; } - // Get whether this HloValue is live out of the computation it is defined in. - bool live_out_of_computation() const { return live_out_of_computation_; } - bool operator==(const HloValue& other) const; bool operator!=(const HloValue& other) const; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index dea47b1fd7b..de4804996f8 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -92,6 +92,7 @@ namespace xla { case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index bdddc232ef7..21bca1d6bef 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -83,7 +83,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, if (std::find(parameter_instructions.begin(), parameter_instructions.end(), &hlo) != parameter_instructions.end()) { - array->AddInvariantLoad(llvm::MDNode::get(*context_, /*MDs=*/{})); + array->MarkInvariantOverWholeProgram(context_); } } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index e3f98ac13e7..7224bd68984 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -256,10 +256,10 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( llvm::Instruction* instruction) const { CHECK(llvm::isa(instruction) || llvm::isa(instruction)); + CHECK(!llvm::isa(instruction) || !is_invariant_) + << "Trying to create a store to an invariant IRArray."; for (const auto& kind_md_pair : metadata_) { - CHECK(kind_md_pair.first != llvm::LLVMContext::MD_invariant_load || - llvm::isa(instruction)); instruction->setMetadata(kind_md_pair.first, kind_md_pair.second); } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 1ed7e99a829..387d4629125 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -229,9 +229,33 @@ class IrArray { AddMetadata(llvm::LLVMContext::MD_noalias, noalias); } - void AddInvariantLoad(llvm::MDNode* invariant_load) { - CHECK_NE(invariant_load, nullptr); - AddMetadata(llvm::LLVMContext::MD_invariant_load, invariant_load); + // Promises LLVM that the data pointed to by this IrArray never changes after + // it's first loaded. + // + // The temporal scope of this promise is the "whole program" from LLVM's point + // of view, but how this translates to HLOs differs between backends. + // + // In the single-threaded CPU backend, we emit one function that + // runs all the HLOs in sequence, so the whole program is the whole HLO + // module. + // + // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO + // in the entry computation). From LLVM's perspective, launching a new kernel + // is like launching a new program, and so the whole program is one top-level + // HLO. Since the scope of the promise is smaller than in the CPU backend, we + // can mark more things as invariant in the GPU backend. + // + // Marking loads as invariant is particularly helpful on GPUs because + // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's + // __ldg intrinsic). These loads use a special cache, and can be + // significantly faster than regular loads. + void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) { + if (is_invariant_) { + return; + } + is_invariant_ = true; + AddMetadata(llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(*context, {})); } const std::map& metadata() const { return metadata_; } @@ -261,6 +285,8 @@ class IrArray { // loads/stores for this array. They keys are the metadata kinds and the // values are the metadata nodes. std::map metadata_; + + bool is_invariant_ = false; }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index a2a442eb1a3..a57ebf59e76 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -63,6 +63,14 @@ void ShapedBuffer::clear() { } } +void ShapedBuffer::AddBufferAtIndex( + const perftools::gputools::DeviceMemoryBase& buffer, + const ShapeIndex& shape_index) { + *mutable_shape_index_to_buffer_entry()->mutable_element(shape_index) = + buffers().size(); + mutable_buffers()->push_back(buffer); +} + const se::DeviceMemoryBase& ShapedBuffer::buffer( const ShapeIndex& index) const { return buffers_[shape_index_to_buffer_entry_.element(index)]; diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index e5ea06fb136..b440948700f 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -75,6 +75,10 @@ class ShapedBuffer { // Set all device memory pointers in the object to null. void clear(); + // Adds a new buffer at the given shape index. + void AddBufferAtIndex(const perftools::gputools::DeviceMemoryBase& buffer, + const ShapeIndex& shape_index); + protected: // The shape of the device buffer with layout. const Shape shape_; diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 4da0a0d3684..fef131d19fc 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -28,12 +28,9 @@ limitations under the License. namespace se = ::perftools::gputools; namespace xla { - -/* static */ tensorflow::mutex* -TransferManager::platform_transfer_manager_mutex() { - static tensorflow::mutex* m = new tensorflow::mutex; - return m; -} +/* static */ tensorflow::mutex + TransferManager::platform_transfer_manager_mutex_( + tensorflow::LINKER_INITIALIZED); /* static */ std::map* @@ -47,7 +44,7 @@ TransferManager::GetPlatformTransferManagers() { se::Platform::Id platform_id, TransferManagerCreationFunction creation_function) { tensorflow::mutex_lock lock( - *TransferManager::platform_transfer_manager_mutex()); + TransferManager::platform_transfer_manager_mutex_); auto* managers = GetPlatformTransferManagers(); CHECK(managers->find(platform_id) == managers->end()); (*managers)[platform_id].creation_function = creation_function; @@ -56,7 +53,7 @@ TransferManager::GetPlatformTransferManagers() { /* static */ StatusOr TransferManager::GetForPlatform( const se::Platform* platform) { tensorflow::mutex_lock lock( - *TransferManager::platform_transfer_manager_mutex()); + TransferManager::platform_transfer_manager_mutex_); auto* managers = GetPlatformTransferManagers(); auto it = managers->find(platform->id()); diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 057bdffe931..d7f85f5765e 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -158,11 +158,8 @@ class TransferManager { const perftools::gputools::Platform* platform); private: - // Routine that returns the mutex that guards the - // platform-to-transfer manager map. Done as a routine to - // ensure correct initialization ordering, since RegisterTransferManager - // can be called during program initialization time. - static tensorflow::mutex* platform_transfer_manager_mutex(); + // The mutex that guards the platform-to-transfer manager map. + static tensorflow::mutex platform_transfer_manager_mutex_; // State kept for each kind of TransferManager. Registration functions // set up creation_function, and then we use that to lazily create diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 2202b6a2c13..c0a0e13f073 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -592,10 +592,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(uint32); case U64: return sizeof(uint64); - case F16: - return sizeof(float) / 2; case BF16: return sizeof(float) / 2; + case F16: + return sizeof(float) / 2; case F32: return sizeof(float); case F64: diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 3e62481629a..63c3541e14f 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -69,7 +69,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_headers_lib", ], ) diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 36d10fff540..f594c609db6 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -248,5 +248,6 @@ def generate_backend_test_macros(backends=[]): deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", "//tensorflow/core:test", ]) diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index 173fb1b0008..978a669bcab 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -21,12 +21,13 @@ limitations under the License. #include #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/regexp.h" namespace xla { namespace { // Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is -// disabled. +// disabled - a sequence of regexps. using ManifestT = std::unordered_map>; ManifestT ReadManifest() { @@ -66,9 +67,6 @@ ManifestT ReadManifest() { string PrependDisabledIfIndicated(const string& test_case_name, const string& test_name) { - // TODO(leary): this code reads the manifest for every test case instantiated - // in every file. Consider switching to a singleton or using a compile-time - // genrule instead. ManifestT manifest = ReadManifest(); // First try full match: test_case_name.test_name @@ -83,11 +81,13 @@ string PrependDisabledIfIndicated(const string& test_case_name, } } + // Expect a full match vs. one of the platform regexps to disable the test. const std::vector& disabled_platforms = it->second; string platform_string = XLA_PLATFORM; - if (std::find(disabled_platforms.begin(), disabled_platforms.end(), - platform_string) != disabled_platforms.end()) { - return "DISABLED_" + test_name; + for (const auto& s : disabled_platforms) { + if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) { + return "DISABLED_" + test_name; + } } // We didn't hit in the disabled manifest entries, so don't disable it. diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index bea0b5ef92a..28a2d0198a7 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -66,8 +66,10 @@ limitations under the License. namespace xla { -// Reads a disabled manifest file (and retains it as a singleton) to resolve -// whether test cases should be disabled on a particular platform. +// Reads a disabled manifest file to resolve whether test cases should be +// disabled on a particular platform. For a test that should be disabled, +// returns DISABLED_ prepended to its name; otherwise returns the test name +// unmodified. string PrependDisabledIfIndicated(const string& test_case_name, const string& test_name); diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index cdd3d66bbba..0d56c9f4836 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -14,8 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/tests/test_utils.h" - #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" namespace xla { @@ -46,6 +47,44 @@ void PopulateWithRandomIntegralData(Literal* literal) { })); } +bool LooksLikeSum(const HloInstruction& instruction) { + return instruction.opcode() == HloOpcode::kAdd && + instruction.operand(0)->opcode() == HloOpcode::kParameter && + instruction.operand(1)->opcode() == HloOpcode::kParameter && + instruction.operand(0) != instruction.operand(1); +} + +// Given an instruction and operand number, replace the given operand with +// a Literal Constant Zero. Handle the case of a fusion instruction by +// replacing the fusion's parent's parameter with a Literal Constant Zero, +// unless the fusion's parent is itself a fusion. +Status MaybeReplaceParameterInputWithZero(HloInstruction* const instruction, + const int64 operand_number) { + CHECK_LT(operand_number, instruction->operand_count()); + if (instruction->operand(operand_number)->opcode() != HloOpcode::kParameter) { + return Status::OK(); + } + + HloComputation* const computation = instruction->parent(); + std::unique_ptr zero = HloInstruction::CreateConstant( + MakeUnique(Literal::Zero(instruction->shape().element_type()))); + + if (computation->IsFusionComputation()) { + HloInstruction* const fusion_instruction = computation->FusionInstruction(); + if (fusion_instruction->IsFused()) { + return Unimplemented( + "Unable to replace fused parameter of fusion instruction"); + } + TF_RETURN_IF_ERROR(fusion_instruction->ReplaceOperandWith( + instruction->operand(operand_number)->parameter_number(), + fusion_instruction->parent()->AddInstruction(std::move(zero)))); + } else { + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith( + operand_number, computation->AddInstruction(std::move(zero)))); + } + return Status::OK(); +} + } // namespace StatusOr> MakeFakeLiteral(const Shape& shape) { @@ -117,4 +156,32 @@ StatusOr>> MakeFakeArguments( return std::move(arguments); } +Status ReplaceInitsWithConstants(HloModule* const module) { + for (HloComputation* const computation : module->computations()) { + for (HloInstruction* const instruction : computation->instructions()) { + const HloOpcode opcode = instruction->opcode(); + if ((opcode == HloOpcode::kReduce || + opcode == HloOpcode::kReduceWindow) && + LooksLikeSum(*instruction->to_apply()->root_instruction())) { + TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 1)); + } else if (opcode == HloOpcode::kSelectAndScatter && + LooksLikeSum(*instruction->scatter()->root_instruction())) { + TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 2)); + } + } + } + return Status::OK(); +} + +Status VerifyHloModule(const perftools::gputools::Platform& platform, + HloModule* const module) { + return HloVerifier( + std::bind( + &TransferManager::GetByteSizeRequirement, + TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(), + std::placeholders::_1)) + .Run(module) + .status(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 12d5255fce5..9aca162a185 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/platform.h" namespace xla { @@ -62,6 +63,16 @@ StatusOr> MakeFakeLiteral(const Shape& shape); StatusOr>> MakeFakeArguments( const HloModule& module); +// Reductions using Adds, ReduceWindow, and SelectAndScatter, require their +// init_value to be replaced with the constant 0.0f when testing, otherwise we +// may generate a bad init_value when looking at the op in isolation. +Status ReplaceInitsWithConstants(HloModule* const module); + +// Check that a given module satisfies various constraints before trying to +// execute it. +Status VerifyHloModule(const perftools::gputools::Platform& platform, + HloModule* const module); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 3e3406e658f..0159d03b11d 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -776,11 +776,32 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, *fusion_kind, operands, *fusion_computation)); break; } + case HloOpcode::kInfeed: { + optional config; + attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; + if (!ParseOperands(&operands, /*expected_size=*/0) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateInfeed(shape, config ? *config : "")); + break; + } + case HloOpcode::kOutfeed: { + optional config; + attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( + shape, operands[0], config ? *config : "")); + break; + } + case HloOpcode::kConditional: case HloOpcode::kCustomCall: case HloOpcode::kReducePrecision: case HloOpcode::kRng: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); @@ -805,7 +826,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_metadata(*metadata); } return AddInstruction(name, instruction); -} +} // NOLINT(readability/fn_size) // ::= '{' (single_sharding | tuple_sharding) '}' // diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 8eeed339b87..0ebc0ca44bb 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -560,6 +560,20 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation } +)" +}, +// infeed/outfeed +{ +"InfeedOutfeed", +R"(HloModule outfeed_module: + +ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) { + %infeed = (u32[3]{0}, pred[]) infeed() + %outfeed = () outfeed((u32[3]{0}, pred[]) %infeed) + ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed() + %outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1) +} + )" } }); @@ -866,7 +880,7 @@ TEST_F(HloParserTest, CommaBetweenSubAttributes) { const string original = R"(HloModule test_comma_module: ENTRY %test_comma.v4 () -> f32[] { - ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="const"} + ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"} } )"; diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 3fa5bcc1df4..6b136d333bb 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -17,3 +17,5 @@ def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0): protoc="@protobuf_archive//:protoc", testonly=testonly, visibility=visibility,) + +ORC_JIT_MEMORY_MAPPER_TARGETS = [] diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index f92b57869ed..a262d4aecdb 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -19,6 +19,7 @@ py_library( srcs = ["__init__.py"] + glob(["python/ops/*.py"]), srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -32,7 +33,6 @@ py_library( "//tensorflow/python:random_ops", "//tensorflow/python:state_ops", "//tensorflow/python:util", - "//tensorflow/python/ops/distributions", "//third_party/py/numpy", ], ) @@ -99,6 +99,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "layers_dense_variational_test", + size = "small", + srcs = ["python/kernel_tests/layers_dense_variational_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + ], +) + cuda_py_test( name = "monte_carlo_test", size = "small", @@ -160,6 +179,27 @@ cuda_py_test( ], ) +cuda_py_test( + name = "sgld_optimizer_test", + size = "small", + srcs = ["python/kernel_tests/sgld_optimizer_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_seed", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index beaf6f1854d..95b9452b1ad 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -25,16 +25,28 @@ from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence from tensorflow.contrib.bayesflow.python.ops import custom_grad from tensorflow.contrib.bayesflow.python.ops import halton_sequence from tensorflow.contrib.bayesflow.python.ops import hmc +from tensorflow.contrib.bayesflow.python.ops import layers from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings from tensorflow.contrib.bayesflow.python.ops import monte_carlo +from tensorflow.contrib.bayesflow.python.ops import optimizers # pylint: enable=unused-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['csiszar_divergence', 'custom_grad', 'entropy', - 'metropolis_hastings', 'monte_carlo', 'halton_sequence', - 'hmc', 'special_math', 'stochastic_variables', - 'variational_inference'] +_allowed_symbols = [ + 'csiszar_divergence', + 'custom_grad', + 'entropy', + 'halton_sequence', + 'hmc', + 'layers', + 'metropolis_hastings', + 'monte_carlo', + 'optimizers', + 'special_math', + 'stochastic_variables', + 'variational_inference', +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py new file mode 100644 index 00000000000..50358fd1c2b --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py @@ -0,0 +1,304 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for dense Bayesian layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test + + +class Counter(object): + """Helper class to manage incrementing a counting `int`.""" + + def __init__(self): + self._value = -1 + + @property + def value(self): + return self._value + + def __call__(self): + self._value += 1 + return self._value + + +class MockDistribution(normal_lib.Normal): + """Monitors DenseVariational calls to the underlying distribution.""" + + def __init__(self, result_sample, result_log_prob, loc=None, scale=None): + self.result_sample = result_sample + self.result_log_prob = result_log_prob + self.result_loc = loc + self.result_scale = scale + self.called_log_prob = Counter() + self.called_sample = Counter() + self.called_loc = Counter() + self.called_scale = Counter() + + def log_prob(self, *args, **kwargs): + self.called_log_prob() + return self.result_log_prob + + def sample(self, *args, **kwargs): + self.called_sample() + return self.result_sample + + @property + def loc(self): + self.called_loc() + return self.result_loc + + @property + def scale(self): + self.called_scale() + return self.result_scale + + +class MockKLDivergence(object): + """Monitors DenseVariational calls to the divergence implementation.""" + + def __init__(self, result): + self.result = result + self.args = [] + self.called = Counter() + + def __call__(self, *args, **kwargs): + self.called() + self.args.append(args) + return self.result + + +class DenseVariationalLocalReparametrization(test.TestCase): + + def testKLPenaltyKernel(self): + with self.test_session(): + dense_vi = prob_layers_lib.DenseVariational(units=2) + inputs = random_ops.random_uniform([2, 3], seed=1) + + # No keys. + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 0) + self.assertListEqual(dense_vi.losses, loss_keys) + + _ = dense_vi(inputs) + + # Yes keys. + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 1) + self.assertListEqual(dense_vi.losses, loss_keys) + + def testKLPenaltyBoth(self): + def _make_normal(dtype, *args): # pylint: disable=unused-argument + return normal_lib.Normal( + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) + with self.test_session(): + dense_vi = prob_layers_lib.DenseVariational( + units=2, + bias_posterior_fn=prob_layers_lib.default_mean_field_normal_fn(), + bias_prior_fn=_make_normal) + inputs = random_ops.random_uniform([2, 3], seed=1) + + # No keys. + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 0) + self.assertListEqual(dense_vi.losses, loss_keys) + + _ = dense_vi(inputs) + + # Yes keys. + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 2) + self.assertListEqual(dense_vi.losses, loss_keys) + + def testVariationalNonLocal(self): + batch_size, in_size, out_size = 2, 3, 4 + with self.test_session() as sess: + seed = Counter() + inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) + + kernel_size = [in_size, out_size] + kernel_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_divergence = MockKLDivergence( + result=random_ops.random_uniform(kernel_size, seed=seed())) + + bias_size = [out_size] + bias_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_divergence = MockKLDivergence( + result=random_ops.random_uniform(bias_size, seed=seed())) + + expected_outputs = ( + math_ops.matmul(inputs, kernel_posterior.result_sample) + + bias_posterior.result_sample) + + dense_vi = prob_layers_lib.DenseVariational( + units=2, + kernel_use_local_reparameterization=False, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + kernel_prior_fn=lambda *args: kernel_prior, + kernel_divergence_fn=kernel_divergence, + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + bias_prior_fn=lambda *args: bias_prior, + bias_divergence_fn=bias_divergence) + + outputs = dense_vi(inputs) + + kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + + [ + expected_outputs_, actual_outputs_, + expected_kernel_, actual_kernel_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_posterior.result_sample, dense_vi.kernel.posterior_tensor, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, dense_vi.bias.posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) + + self.assertAllClose( + expected_kernel_, actual_kernel_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) + + self.assertAllEqual( + [[kernel_posterior, kernel_prior, kernel_posterior.result_sample]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior, bias_prior, bias_posterior.result_sample]], + bias_divergence.args) + + def testVariationalLocal(self): + batch_size, in_size, out_size = 2, 3, 4 + with self.test_session() as sess: + seed = Counter() + inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) + + kernel_size = [in_size, out_size] + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform(kernel_size, seed=seed()), + scale=random_ops.random_uniform(kernel_size, seed=seed()), + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_divergence = MockKLDivergence( + result=random_ops.random_uniform(kernel_size, seed=seed())) + + bias_size = [out_size] + bias_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_divergence = MockKLDivergence( + result=random_ops.random_uniform(bias_size, seed=seed())) + + expected_kernel_posterior_affine = normal_lib.Normal( + loc=math_ops.matmul(inputs, kernel_posterior.result_loc), + scale=math_ops.matmul( + inputs**2., kernel_posterior.result_scale**2)**0.5) + expected_kernel_posterior_affine_tensor = ( + expected_kernel_posterior_affine.sample(seed=42)) + expected_outputs = (expected_kernel_posterior_affine_tensor + + bias_posterior.result_sample) + + dense_vi = prob_layers_lib.DenseVariational( + units=2, + kernel_use_local_reparameterization=True, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + kernel_prior_fn=lambda *args: kernel_prior, + kernel_divergence_fn=kernel_divergence, + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + bias_prior_fn=lambda *args: bias_prior, + bias_divergence_fn=bias_divergence) + + outputs = dense_vi(inputs) + + kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + + [ + expected_outputs_, actual_outputs_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, dense_vi.bias.posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) + + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) + + self.assertAllEqual( + [[kernel_posterior, kernel_prior, None]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior, bias_prior, bias_posterior.result_sample]], + bias_divergence.args) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py new file mode 100644 index 00000000000..66793383fdd --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py @@ -0,0 +1,209 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional test for GradientDescent.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import math +from tensorflow.contrib.bayesflow.python.ops.optimizers import SGLDOptimizer +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class SGLDOptimizerTest(test.TestCase): + + def testBasic(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.53 + sgd_op = SGLDOptimizer( + 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + + def testBasicMultiInstance(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + vara = variables.Variable([1.1, 2.1], dtype=dtype) + varb = variables.Variable([3.0, 4.0], dtype=dtype) + gradsa = constant_op.constant([0.1, 0.1], dtype=dtype) + gradsb = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.5 + sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate) + sgd_op = sgd_optimizer.apply_gradients( + zip([grads0, grads1], [var0, var1])) + sgd_optimizer2 = SGLDOptimizer( + 3.0, preconditioner_decay_rate=decay_rate) + sgd_op2 = sgd_optimizer2.apply_gradients( + zip([gradsa, gradsb], [vara, varb])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval()) + + # Run 1 step of sgd + sgd_op.run() + sgd_op2.run() + # Validate updated params + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], vara.eval()) + + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], varb.eval()) + self.assertNotEqual(sgd_optimizer.variable_scope, + sgd_optimizer2.variable_scope) + self.assertNotEqual(sgd_optimizer.variable_scope.name, + sgd_optimizer2.variable_scope.name) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + lrate = constant_op.constant(3.0) + decay_rate = 0.5 + sgd_op = SGLDOptimizer( + lrate, preconditioner_decay_rate=constant_op.constant( + decay_rate)).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + + def testGradWrtRef(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + opt = SGLDOptimizer(3.0) + values = [1.0, 3.0] + vars_ = [variables.Variable([v], dtype=dtype) for v in values] + grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_) + variables.global_variables_initializer().run() + for grad, _ in grads_and_vars: + self.assertAllCloseAccordingToType([1.0], grad.eval()) + + def testWithGlobalStep(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + global_step = variables.Variable(0, trainable=False) + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.1 + sgd_op = SGLDOptimizer( + 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + + # Validate updated params and global_step + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + self.assertAllCloseAccordingToType(1, global_step.eval()) + + def testSparseBasic(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([[1.1], [2.1]], dtype=dtype) + var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) + grads0 = ops.IndexedSlices( + constant_op.constant([0.1], shape=[1, 1], dtype=dtype), + constant_op.constant([0]), constant_op.constant([2, 1])) + grads1 = ops.IndexedSlices( + constant_op.constant([0.01], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), constant_op.constant([2, 1])) + decay_rate = 0.9 + sgd_op = SGLDOptimizer( + 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval()) + self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType([[1.1 - 3.0 * grads_scaled], [2.1]], + var0.eval()) + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [[3.0 - 3.0 * 0], [4.0 - 3.0 * grads_scaled]], var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/layers.py b/tensorflow/contrib/bayesflow/python/ops/layers.py new file mode 100644 index 00000000000..dcead38af82 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers.py @@ -0,0 +1,37 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Probabilistic neural layers. + +See ${python/contrib.bayesflow.layers}. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'DenseVariational', + 'dense_variational', + 'default_loc_scale_fn', + 'default_mean_field_normal_fn', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py new file mode 100644 index 00000000000..b05ce0ffc1d --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py @@ -0,0 +1,797 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dense Bayesian layer using KL-divergence based variational inference. + +@@DenseVariational +@@dense_variational + +@@default_loc_scale_fn +@@default_mean_field_normal_fn +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as layers_lib +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops.distributions import kullback_leibler as kl_lib +from tensorflow.python.ops.distributions import normal as normal_lib + + +__all__ = [ + "DenseVariational", + "dense_variational", + "default_loc_scale_fn", + "default_mean_field_normal_fn", +] + + +def default_loc_scale_fn( + is_singular=False, + loc_initializer=init_ops.random_normal_initializer(stddev=0.1), + untransformed_scale_initializer=init_ops.random_normal_initializer( + mean=-3., stddev=0.1), + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. + + This function produces a closure which produces `loc`, `scale` using + `tf.get_variable`. The closure accepts the following arguments: + + dtype: Type of parameter's event. + shape: Python `list`-like representing the parameter's event shape. + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` indicating if `scale is None`. Default: `False`. + loc_initializer: Initializer function for the `loc` parameters. + The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. Default value: `tf.random_normal_initializer(mean=-3., + stddev=0.1)`. This implies the softplus transformed result has mean + approximately `0.05` and std. deviation approximately `0.005`. + loc_regularizer: Regularizer function for the `loc` parameters. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. The default (`None`) is to use the `tf.get_variable` default. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. The default + (`None`) is to use the `tf.get_variable` default. + + Returns: + default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` + parameters from args: `dtype, shape, name, trainable, add_variable_fn`. + """ + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates `loc`, `scale` parameters.""" + loc = add_variable_fn( + name=name + "_loc", + shape=shape, + initializer=loc_initializer, + regularizer=loc_regularizer, + constraint=loc_constraint, + dtype=dtype, + trainable=trainable) + if is_singular: + return loc, None + untransformed_scale = add_variable_fn( + name=name + "_untransformed_scale", + shape=shape, + initializer=untransformed_scale_initializer, + regularizer=untransformed_scale_regularizer, + constraint=untransformed_scale_constraint, + dtype=dtype, + trainable=trainable) + scale = (np.finfo(dtype.as_numpy_dtype).eps + + nn_ops.softplus(untransformed_scale)) + return loc, scale + return _fn + + +def default_mean_field_normal_fn( + is_singular=False, + loc_initializer=None, + untransformed_scale_initializer=None, + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Creates a function to build Normal distributions with trainable params. + + This function produces a closure which produces `tf.distributions.Normal` + parameterized by a loc` and `scale` each created using `tf.get_variable`. The + produced closure accepts the following arguments: + + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` if `True`, forces the special case limit of + `scale->0`, i.e., a `Deterministic` distribution. + loc_initializer: Initializer function for the `loc` parameters. + If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + loc_regularizer: Regularizer function for the `loc` parameters. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. + + Returns: + make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` + using from args: `dtype, shape, name, trainable, add_variable_fn`. + """ + loc_scale_fn_ = default_loc_scale_fn( + is_singular, + loc_initializer, + untransformed_scale_initializer, + loc_regularizer, + untransformed_scale_regularizer, + loc_constraint, + untransformed_scale_constraint) + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates a batch of `Deterministic` or `Normal` distributions.""" + loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) + if scale is None: + return deterministic_lib.Deterministic(loc=loc) + return normal_lib.Normal(loc=loc, scale=scale) + return _fn + + +class DenseVariational(layers_lib.Layer): + """Densely-connected variational class. + + This layer implements the Bayesian variational inference analogue to: + `outputs = activation(matmul(inputs, kernel) + bias)` + by assuming the `kernel` and/or the `bias` are random variables. + + The layer implements a stochastic dense calculation by making a Monte Carlo + approximation of a [variational Bayesian method based on KL divergence]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., + + ```none + -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw + = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw + <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's + = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] + ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } + + KL[q(W|x), p(W)] + ``` + + where `W` denotes the (independent) `kernel` and `bias` random variables, `w` + is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`, + and `~=` denotes an approximation which becomes exact as `m->inf`. The above + bound is sometimes referred to as the negative Evidence Lower BOund or + negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this + layer is appropriate to use when the final loss is a negative log-likelihood. + + The Monte-Carlo sum portion is used for the feed-forward calculation of the + DNN. The KL divergence portion can be added to the final loss via: + `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + random variables (which together comprise `W`). + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + When `True`, `kernel_posterior_fn` must create an instance of + `tf.distributions.Normal`. + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + kernel: `VariationalKernelParamater` instance containing all `kernel` + related properties and `callable`s. + bias: `VariationalParameter` instance containing all `kernel` + related properties and `callable`s. + """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_use_local_reparameterization=True, + kernel_posterior_fn=default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(DenseVariational, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self._units = units + self._activation = activation + self._input_spec = layers_lib.InputSpec(min_ndim=2) + self._kernel_use_local_reparameterization = ( + kernel_use_local_reparameterization) + self._kernel = VariationalKernelParameter( + kernel_posterior_fn, + kernel_posterior_tensor_fn, + kernel_prior_fn, + kernel_divergence_fn) + self._bias = VariationalParameter( + bias_posterior_fn, + bias_posterior_tensor_fn, + bias_prior_fn, + bias_divergence_fn) + + @property + def units(self): + return self._units + + @property + def activation(self): + return self._activation + + @property + def input_spec(self): + return self._input_spec + + @input_spec.setter + def input_spec(self, value): + self._input_spec = value + + @property + def kernel_use_local_reparameterization(self): + return self._kernel_use_local_reparameterization + + @property + def kernel(self): + return self._kernel + + @property + def bias(self): + return self._bias + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + in_size = input_shape.with_rank_at_least(2)[-1].value + if in_size is None: + raise ValueError("The last dimension of the inputs to `Dense` " + "should be defined. Found `None`.") + self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size}) + dtype = dtypes.as_dtype(self.dtype) + + # Must have a posterior kernel. + self.kernel.posterior = self.kernel.posterior_fn( + dtype, [in_size, self.units], "kernel_posterior", + self.trainable, self.add_variable) + + if self.kernel.prior_fn is None: + self.kernel_prior = None + else: + self.kernel.prior = self.kernel.prior_fn( + dtype, [in_size, self.units], "kernel_prior", + self.trainable, self.add_variable) + self._built_kernel_divergence = False + + if self.bias.posterior_fn is None: + self.bias.posterior = None + else: + self.bias.posterior = self.bias.posterior_fn( + dtype, [self.units], "bias_posterior", + self.trainable, self.add_variable) + + if self.bias.prior_fn is None: + self.bias.prior = None + else: + self.bias.prior = self.bias.prior_fn( + dtype, [self.units], "bias_prior", + self.trainable, self.add_variable) + self._built_bias_divergence = False + + self.built = True + + def call(self, inputs): + inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) + + outputs = self._apply_variational_kernel(inputs) + outputs = self._apply_variational_bias(outputs) + if self.activation is not None: + outputs = self.activation(outputs) # pylint: disable=not-callable + if not self._built_kernel_divergence: + self._apply_divergence(self.kernel, name="divergence_kernel") + self._built_kernel_divergence = True + if not self._built_bias_divergence: + self._apply_divergence(self.bias, name="divergence_bias") + self._built_bias_divergence = True + return outputs + + def _apply_variational_kernel(self, inputs): + if not self.kernel_use_local_reparameterization: + self.kernel.posterior_tensor = self.kernel.posterior_tensor_fn( + self.kernel.posterior) + self.kernel.posterior_affine = None + self.kernel.posterior_affine_tensor = None + return self._matmul(inputs, self.kernel.posterior_tensor) + if not isinstance(self.kernel.posterior, normal_lib.Normal): + raise TypeError("`kernel_use_local_reparameterization=True` requires " + "`kernel_posterior_fn` produce an instance of " + "`tf.distributions.Normal` (saw: \"{}\").".format( + type(self.kernel.posterior).__name__)) + self.kernel.posterior_affine = normal_lib.Normal( + loc=self._matmul(inputs, self.kernel.posterior.loc), + scale=standard_ops.sqrt(self._matmul( + standard_ops.square(inputs), + standard_ops.square(self.kernel.posterior.scale)))) + self.kernel.posterior_affine_tensor = ( + self.kernel.posterior_tensor_fn(self.kernel.posterior_affine)) + self.kernel.posterior_tensor = None + return self.kernel.posterior_affine_tensor + + def _apply_variational_bias(self, inputs): + if self.bias.posterior is None: + self.bias.posterior_tensor = None + return inputs + self.bias.posterior_tensor = self.bias.posterior_tensor_fn( + self.bias.posterior) + return nn.bias_add(inputs, self.bias.posterior_tensor) + + def _apply_divergence(self, param, name): + if (param.divergence_fn is None or + param.posterior is None or + param.prior is None): + param.divergence = None + return + param.divergence = standard_ops.identity( + param.divergence_fn( + param.posterior, param.prior, param.posterior_tensor), + name=name) + self.add_loss(param.divergence) + + def _matmul(self, inputs, kernel): + if inputs.shape.ndims <= 2: + return standard_ops.matmul(inputs, kernel) + # To handle broadcasting, we must use `tensordot`. + return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]]) + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2) + if input_shape[-1].value is None: + raise ValueError( + "The innermost dimension of input_shape must be defined, " + "but saw: {}".format(input_shape)) + return input_shape[:-1].concatenate(self.units) + + +def dense_variational( + inputs, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_use_local_reparameterization=True, + kernel_posterior_fn=default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Densely-connected variational layer. + + This layer implements the Bayesian variational inference analogue to: + `outputs = activation(matmul(inputs, kernel) + bias)` + by assuming the `kernel` and/or the `bias` are random variables. + + The layer implements a stochastic dense calculation by making a Monte Carlo + approximation of a [variational Bayesian method based on KL divergence]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., + + ```none + -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw + = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw + <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's + = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] + ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } + + KL[q(W|x), p(W)] + ``` + + where `W` denotes the (independent) `kernel` and `bias` random variables, `w` + is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`, + and `~=` denotes an approximation which becomes exact as `m->inf`. The above + bound is sometimes referred to as the negative Evidence Lower BOund or + negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this + layer is appropriate to use when the final loss is a negative log-likelihood. + + The Monte-Carlo sum portion is used for the feed-forward calculation of the + DNN. The KL divergence portion can be added to the final loss via: + `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + random variables (which together comprise `W`). + + Args: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + When `True`, `kernel_posterior_fn` must create an instance of + `tf.distributions.Normal`. + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Returns: + output: `Tensor` representing a the affine transformed input under a random + draw from the surrogate posterior distribution. + """ + layer = DenseVariational( + units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_use_local_reparameterization=( + kernel_use_local_reparameterization), + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class NotSet(object): + """Helper to track whether a `VariationalParameter` value has been set.""" + pass + + +class VariationalParameter(object): + """Struct-like container of variational parameter properties. + + A `VariationalParameter` is intitialized with Python `callable`s which set the + value of correspondingly named members. Corresponding values have "set once" + semantics, i.e., once set to any value they are immutable. + """ + + def __init__( + self, + posterior_fn, + posterior_tensor_fn, + prior_fn, + divergence_fn): + """Creates the `VariationalParameter` struct-like object. + + Args: + posterior_fn: Python `callable` which creates a + `tf.distribution.Distribution` like object representing the posterior + distribution. See `VariationalParameter.posterior_fn` for `callable`'s + required parameters. + posterior_tensor_fn: Python `callable` which computes a `Tensor` + which represents the `posterior`. + prior_fn: Python `callable` which creates a + `tf.distribution.Distribution` like object representing the prior + distribution. See `VariationalParameter.prior_fn` for `callable`'s + required parameters. + divergence_fn: Python `callable` which computes the KL divergence from + `posterior` to `prior`. See `VariationalParameter.divergence_fn` for + required `callable`'s parameters. + """ + self._posterior_fn = posterior_fn + self._posterior = NotSet() + self._posterior_tensor_fn = posterior_tensor_fn + self._posterior_tensor = NotSet() + self._prior_fn = prior_fn + self._prior = NotSet() + self._divergence_fn = divergence_fn + self._divergence = NotSet() + self._init_helper() + + @property + def posterior_fn(self): + """`callable` which creates `tf.distributions.Distribution`-like posterior. + + The `callable` must accept the following parameters: + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Returns: + posterior_fn: The Python `callable` specified in `__init__`. + """ + return self._posterior_fn + + @property + def posterior(self): + """`tf.distributions.Distribution`-like instance representing posterior.""" + return self._posterior + + @posterior.setter + def posterior(self, value): + """One-time setter of the `posterior` distribution.""" + if not isinstance(self._posterior, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior = value + + @property + def posterior_tensor_fn(self): + """Creates `Tensor` representing the `posterior` distribution. + + The `callable` must accept the following parameters: + posterior: `tf.distributions.Distribution`-like instance. + + Returns: + posterior_tensor_fn: The Python `callable` specified in + `__init__`. + """ + return self._posterior_tensor_fn + + @property + def posterior_tensor(self): + """`Tensor` representing the `posterior` distribution.""" + return self._posterior_tensor + + @posterior_tensor.setter + def posterior_tensor(self, value): + """One-time setter of the `posterior_tensor`.""" + if not isinstance(self._posterior_tensor, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior_tensor = value + + @property + def prior_fn(self): + """`callable` which creates `tf.distributions.Distribution`-like prior. + + The `callable` must accept the following parameters: + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Returns: + prior_fn: The Python `callable` specified in `__init__`. + """ + return self._prior_fn + + @property + def prior(self): + """`tf.distributions.Distribution`-like instance representing posterior.""" + return self._prior + + @prior.setter + def prior(self, value): + """One-time setter of the `prior` distribution.""" + if not isinstance(self._prior, NotSet): + raise ValueError("Cannot override already set attribute.") + self._prior = value + + @property + def divergence_fn(self): + """`callable` which computes KL-divergence `Tensor` from posterior to prior. + + The `callable` must accept the following parameters: + posterior: `tf.distributions.Distribution`-like instance. + prior: `tf.distributions.Distribution`-like instance. + posterior_tensor: `Tensor` representing value of posterior. + + Returns: + divergence_fn: The Python `callable` specified in `__init__`. + """ + return self._divergence_fn + + @property + def divergence(self): + """`Tensor` representing KL-divergence from posterior to prior.""" + return self._divergence + + @divergence.setter + def divergence(self, value): + """One-time setter of the `divergence`.""" + if not isinstance(self._divergence, NotSet): + raise ValueError("Cannot override already set attribute.") + self._divergence = value + + def _init_helper(self): + pass + + +class VariationalKernelParameter(VariationalParameter): + """Struct-like container of variational kernel properties. + + A `VariationalKernelParameter` is intitialized with Python `callable`s which + set the value of correspondingly named members. Corresponding values have "set + once" semantics, i.e., once set to any value they are immutable. + """ + + @property + def posterior_affine(self): + """`tf.distributions.Distribution` affine transformed posterior.""" + return self._posterior_affine + + @posterior_affine.setter + def posterior_affine(self, value): + """One-time setter of `posterior_affine`.""" + if not isinstance(self._posterior_affine, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior_affine = value + + @property + def posterior_affine_tensor(self): + """`Tensor` representing the `posterior_affine` distribution.""" + return self._posterior_affine_tensor + + @posterior_affine_tensor.setter + def posterior_affine_tensor(self, value): + """One-time setter of the `posterior_affine_tensor`.""" + if not isinstance(self._posterior_affine_tensor, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior_affine_tensor = value + + def _init_helper(self): + self._posterior_affine = NotSet() + self._posterior_affine_tensor = NotSet() diff --git a/tensorflow/contrib/bayesflow/python/ops/optimizers.py b/tensorflow/contrib/bayesflow/python/ops/optimizers.py new file mode 100644 index 00000000000..ee32e6b5c3d --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/optimizers.py @@ -0,0 +1,34 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Probabilistic optimizer modules. + +See ${python/contrib.bayesflow.optimizers}. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.bayesflow.python.ops.sgld_optimizer import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'SGLDOptimizer', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py new file mode 100644 index 00000000000..5d36ea7a2b5 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py @@ -0,0 +1,216 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An optimizer module for stochastic gradient Langevin dynamics.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope as varscope_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops + + +class SGLDOptimizer(optimizer.Optimizer): + """An optimizer module for stochastic gradient Langevin dynamics. + + This implements the preconditioned Stochastic Gradient Langevin Dynamics + optimizer [1]. The optimization variable is regarded as a sample from the + posterior under Stochastic Gradient Langevin Dynamics with noise rescaled in + each dimension according to RMSProp [2]. + + Note: If a prior is included in the loss, it should be scaled by + `1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches + in the data. I.e., it should be divided by the `num_pseudo_batches` term + described below. + + [1]: "Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural + Networks." Chunyuan Li, Changyou Chen, David Carlson, Lawrence Carin. + ArXiv:1512.07666, 2015. https://arxiv.org/abs/1512.07666 + [2]: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf + + Args: + learning_rate: Scalar `float`-like `Tensor`. The base learning rate for the + optimizer. Must be tuned to the specific function being minimized. + preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential + decay rate of the rescaling of the preconditioner (RMSprop). (This is + "alpha" in [1]). Should be smaller than but nearly `1` to approximate + sampling from the posterior. (Default: `0.95`) + num_pseudo_batches: Scalar `int`-like `Tensor`. The effective number of + minibatches in the data set. Trades off noise and prior with the SGD + likelihood term. Note: Assumes the loss is taken as the mean over a + minibatch. Otherwise if the sum was taken, divide this number by the + batch size. (Default: `1`) + burnin: Scalar `int`-like `Tensor`. The number of iterations to collect + gradient statistics to update the preconditioner before starting to draw + noisy samples. (Default: `25`) + diagonal_bias: Scalar `float`-like `Tensor`. Term added to the diagonal of + the preconditioner to prevent the preconditioner from degenerating. + (Default: `1e-8`) + name: Python `str` describing ops managed by this function. + (Default: `"SGLDOptimizer"`) + variable_scope: Variable scope used for calls to `tf.get_variable`. + If `None`, a new variable scope is created using name + `ops.get_default_graph().unique_name(name or default_name)`. + + Raises: + InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in + `(0,1]`. + """ + + def __init__(self, + learning_rate, + preconditioner_decay_rate=0.95, + num_pseudo_batches=1, + burnin=25, + diagonal_bias=1e-8, + name=None, + variable_scope=None): + default_name = 'SGLDOptimizer' + with ops.name_scope(name, default_name, [ + learning_rate, preconditioner_decay_rate, num_pseudo_batches, burnin, + diagonal_bias + ]): + if variable_scope is None: + var_scope_name = ops.get_default_graph().unique_name( + name or default_name) + with varscope_ops.variable_scope(var_scope_name) as scope: + self._variable_scope = scope + else: + self._variable_scope = variable_scope + + self._preconditioner_decay_rate = ops.convert_to_tensor( + preconditioner_decay_rate, name='preconditioner_decay_rate') + self._num_pseudo_batches = ops.convert_to_tensor( + num_pseudo_batches, name='num_pseudo_batches') + self._burnin = ops.convert_to_tensor(burnin, name='burnin') + self._diagonal_bias = ops.convert_to_tensor( + diagonal_bias, name='diagonal_bias') + self._learning_rate = ops.convert_to_tensor( + learning_rate, name='learning_rate') + + with varscope_ops.variable_scope(self._variable_scope): + self._counter = varscope_ops.get_variable( + 'counter', initializer=0, trainable=False) + + self._preconditioner_decay_rate = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._preconditioner_decay_rate, + message='`preconditioner_decay_rate` must be non-negative'), + check_ops.assert_less_equal( + self._preconditioner_decay_rate, + 1., + message='`preconditioner_decay_rate` must be at most 1.'), + ], self._preconditioner_decay_rate) + + self._num_pseudo_batches = control_flow_ops.with_dependencies([ + check_ops.assert_greater( + self._num_pseudo_batches, + 0, + message='`num_pseudo_batches` must be greater than zero') + ], self._num_pseudo_batches) + + self._burnin = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._burnin, message='`burnin` must be non-negative'), + check_ops.assert_integer( + self._burnin, message='`burnin` must be an integer') + ], self._burnin) + + self._diagonal_bias = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._diagonal_bias, + message='`diagonal_bias` must be non-negative') + ], self._diagonal_bias) + + super(SGLDOptimizer, self).__init__(use_locking=False, + name=name or default_name) + + def _create_slots(self, var_list): + for v in var_list: + init_rms = init_ops.ones_initializer(dtype=v.dtype) + self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(), + v.dtype, 'rms', self._name) + + def _prepare(self): + # We need to put the conversion and check here because a user will likely + # want to decay the learning rate dynamically. + self._learning_rate_tensor = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._learning_rate, message='`learning_rate` must be non-negative') + ], ops.convert_to_tensor(self._learning_rate, name='learning_rate_tensor')) + self._decay_tensor = ops.convert_to_tensor( + self._preconditioner_decay_rate, name='preconditioner_decay_rate') + + super(SGLDOptimizer, self)._prepare() + + def _apply_dense(self, grad, var): + rms = self.get_slot(var, 'rms') + + with ops.control_dependencies([ + self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, + var.dtype.base_dtype))]): + new_grad = self._apply_noisy_update(rms, grad) + + return training_ops.apply_gradient_descent( + var, + math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), + new_grad, + use_locking=self._use_locking).op + + def _apply_sparse(self, grad, var): + rms = self.get_slot(var, 'rms') + + with ops.control_dependencies([ + self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, + var.dtype.base_dtype))]): + new_grad = self._apply_noisy_update(rms, grad) + + return training_ops.apply_gradient_descent( + var, + math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), + new_grad, + use_locking=self._use_locking).op + + @property + def variable_scope(self): + """Variable scope of all calls to `tf.get_variable`.""" + return self._variable_scope + + def _apply_noisy_update(self, mom, grad): + # Compute and apply the gradient update following + # preconditioned Langevin dynamics + stddev = array_ops.where( + array_ops.squeeze(self._counter > self._burnin), + math_ops.cast(math_ops.rsqrt(self._learning_rate), grad.dtype), + array_ops.zeros([], grad.dtype)) + + preconditioner = math_ops.rsqrt( + mom + math_ops.cast(self._diagonal_bias, grad.dtype)) + return ( + 0.5 * preconditioner * grad * math_ops.cast(self._num_pseudo_batches, + grad.dtype) + + random_ops.random_normal(array_ops.shape(grad), 1.0, dtype=grad.dtype) * + stddev * math_ops.sqrt(preconditioner)) + + def _update_momentum(self, mom, grad, decay): + # Keep an exponentially weighted moving average of squared gradients. + # Not thread safe + return mom.assign_add((1.0 - decay) * (math_ops.square(grad) - mom)) diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index 766982b4f20..f8086b0c2bb 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -63,19 +63,26 @@ const char* kPredictionsTensorName = "predictions"; void CalculateTreesToInclude( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, const std::vector& trees_to_drop, const int32 num_trees, - const bool only_finalized, std::vector* trees_to_include) { + const bool only_finalized, const bool center_bias, + std::vector* trees_to_include) { trees_to_include->reserve(num_trees - trees_to_drop.size()); int32 index = 0; // This assumes that trees_to_drop is a sorted list of tree ids. for (int32 tree = 0; tree < num_trees; ++tree) { - if ((!trees_to_drop.empty() && index < trees_to_drop.size() && - trees_to_drop[index] == tree) || - (only_finalized && config.tree_metadata_size() > 0 && - !config.tree_metadata(tree).is_finalized())) { + // Skip the tree if tree is in the list of trees_to_drop. + if (!trees_to_drop.empty() && index < trees_to_drop.size() && + trees_to_drop[index] == tree) { ++index; continue; } + // Or skip if the tree is not finalized and only_finalized is set, + // with the exception of centering bias. + if (only_finalized && !(center_bias && tree == 0) && + config.tree_metadata_size() > 0 && + !config.tree_metadata(tree).is_finalized()) { + continue; + } trees_to_include->push_back(tree); } } @@ -250,7 +257,7 @@ class GradientTreesPredictionOp : public OpKernel { CalculateTreesToInclude( ensemble_resource->decision_tree_ensemble(), dropped_trees, ensemble_resource->decision_tree_ensemble().trees_size(), - only_finalized_trees_, &trees_to_include); + only_finalized_trees_, center_bias_, &trees_to_include); // Allocate output predictions matrix. Tensor* output_predictions_t = nullptr; diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 2a5c7949f2d..c77d90e243c 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -237,6 +237,7 @@ class CenterTreeEnsembleBiasOp : public OpKernel { VLOG(1) << "Continuing to center bias, delta=" << total_delta; } else { VLOG(1) << "Done centering bias, delta=" << total_delta; + ensemble_resource->LastTreeMetadata()->set_is_finalized(true); } Tensor* continue_centering_t = nullptr; OP_REQUIRES_OK( @@ -260,7 +261,6 @@ class CenterTreeEnsembleBiasOp : public OpKernel { for (size_t idx = 0; idx < logits_dimension; ++idx) { leaf->mutable_vector()->add_value(0.0); } - ensemble_resource->LastTreeMetadata()->set_is_finalized(true); return leaf; } else if (num_trees == 1) { // Confirms that the only tree is a bias and returns its leaf. diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index f0413fee5a8..c2e65b643df 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -181,7 +181,6 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): tree_weights: 1.0 tree_metadata { num_layers_grown: 1 - is_finalized: true } growing_metadata { num_trees_attempted: 1 @@ -189,7 +188,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): } """ self.assertEqual(new_stamp, 1) - self.assertEqual(stats.num_trees, 1) + self.assertEqual(stats.num_trees, 0) self.assertEqual(stats.num_layers, 1) self.assertEqual(stats.active_tree, 1) self.assertEqual(stats.active_layer, 1) @@ -231,7 +230,6 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): tree_weights: 1.0 tree_metadata { num_layers_grown: 1 - is_finalized: true } growing_metadata { num_trees_attempted: 1 @@ -239,7 +237,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): } """ self.assertEqual(new_stamp, 2) - self.assertEqual(stats.num_trees, 1) + self.assertEqual(stats.num_trees, 0) self.assertEqual(stats.num_layers, 1) self.assertEqual(stats.active_tree, 1) self.assertEqual(stats.active_layer, 1) diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 3e2a858b8ae..61b3fd715dd 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -238,6 +238,7 @@ add_python_module("tensorflow/python/keras/datasets") add_python_module("tensorflow/python/keras/datasets/boston_housing") add_python_module("tensorflow/python/keras/datasets/cifar10") add_python_module("tensorflow/python/keras/datasets/cifar100") +add_python_module("tensorflow/python/keras/datasets/fashion_mnist") add_python_module("tensorflow/python/keras/datasets/imdb") add_python_module("tensorflow/python/keras/datasets/mnist") add_python_module("tensorflow/python/keras/datasets/reuters") diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 9174c5eb989..964ec754413 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -23,7 +23,6 @@ import itertools import numpy as np from tensorflow.contrib.crf.python.ops import crf -from tensorflow.python.framework import dtypes from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -58,18 +57,19 @@ class CrfTest(test.TestCase): def testCrfUnaryScore(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) - tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) - sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: - unary_score = crf.crf_unary_score( - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - inputs=array_ops.expand_dims(inputs, 0)) - unary_score = array_ops.squeeze(unary_score, [0]) - tf_unary_score = sess.run(unary_score) - expected_unary_score = sum(inputs[i][tag_indices[i]] - for i in range(sequence_lengths)) - self.assertAllClose(tf_unary_score, expected_unary_score) + for dtype in (np.int32, np.int64): + tag_indices = np.array([1, 2, 1, 0], dtype=dtype) + sequence_lengths = np.array(3, dtype=np.int32) + with self.test_session() as sess: + unary_score = crf.crf_unary_score( + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + inputs=array_ops.expand_dims(inputs, 0)) + unary_score = array_ops.squeeze(unary_score, [0]) + tf_unary_score = sess.run(unary_score) + expected_unary_score = sum(inputs[i][tag_indices[i]] + for i in range(sequence_lengths)) + self.assertAllClose(tf_unary_score, expected_unary_score) def testCrfBinaryScore(self): tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index c8adb0369b9..8b621732c13 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -193,6 +193,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): offsets = array_ops.expand_dims( math_ops.range(batch_size) * max_seq_len * num_tags, 1) offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0) + # Use int32 or int64 based on tag_indices' dtype. + if tag_indices.dtype == dtypes.int64: + offsets = math_ops.to_int64(offsets) flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1]) unary_scores = array_ops.reshape( @@ -305,7 +308,7 @@ def viterbi_decode(score, transition_params): Returns: viterbi: A [seq_len] list of integers containing the highest scoring tag - indicies. + indices. viterbi_score: A float containing the score for the Viterbi sequence. """ trellis = np.zeros_like(score) @@ -385,7 +388,7 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell): """Initialize the CrfDecodeBackwardRnnCell. Args: - num_tags: An integer. + num_tags: An integer. The number of tags. """ self._num_tags = num_tags @@ -434,9 +437,9 @@ def crf_decode(potentials, transition_params, sequence_length): sequence_length: A [batch_size] vector of true sequence lengths. Returns: - decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. - Contains the highest scoring tag indicies. - best_score: A [batch_size] vector, containing the score of `decode_tags`. + decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32. + Contains the highest scoring tag indices. + best_score: A [batch_size] tensor, containing the score of decode_tags. """ # For simplicity, in shape comments, denote: # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 78894c98556..5b635226f09 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -270,6 +270,7 @@ py_test( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -277,15 +278,20 @@ py_test( "//tensorflow/python:data_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python:functional_ops", "//tensorflow/python:io_ops", "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:platform", "//tensorflow/python:random_ops", "//tensorflow/python:script_ops", "//tensorflow/python:string_ops", + "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 8ccf92c17aa..d8e7f9d5933 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -25,9 +25,13 @@ import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops @@ -40,7 +44,10 @@ from tensorflow.python.ops import script_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import compat @@ -668,6 +675,515 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testCaptureResourceInMapFn(self): + + def _build_ds(iterator): + + def _map_fn(x): + get_next = iterator.get_next() + return x * get_next + + return dataset_ops.Dataset.range(10).map(_map_fn) + + def _build_graph(): + captured_iterator = dataset_ops.Dataset.range( + 10).make_initializable_iterator() + ds = _build_ds(captured_iterator) + iterator = ds.make_initializable_iterator() + init_op = iterator.initializer + return captured_iterator.initializer, init_op + + with ops.Graph().as_default() as g: + captured_init_op, init_op = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(captured_init_op) + with self.assertRaises(errors.UnimplementedError): + # CapturedFunction does not support capturing IteratorResource. + sess.run(init_op) + + +class MapDatasetSerializationTest(test.TestCase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 14 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def tearDown(self): + # Remove all checkpoint files. + prefix = self._ckpt_path() + pattern = prefix + "*" + files = gfile.Glob(pattern) + map(gfile.Remove, files) + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(self._num_epochs)) + + def _build_graph(self, multiplier=37.0, build_saveable=True): + ds = self._build_ds(multiplier) + iterator = ds.make_initializable_iterator() + + if build_saveable: + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + self._add_iterator_ops_to_collection(init_op, get_next) + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + def _build_empty_graph(self, output_types, output_shapes): + iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saver = saver_lib.Saver() + get_next = iterator.get_next() + return get_next, saver + + def _add_iterator_ops_to_collection(self, init_op, get_next): + ops.add_to_collection("iterator_ops", init_op) + ops.add_to_collection("iterator_ops", get_next[0]) + ops.add_to_collection("iterator_ops", get_next[1]) + ops.add_to_collection("iterator_ops", get_next[2]) + + def _get_iterator_ops_from_collection(self): + init_op, get_next_1, get_next_2, get_next_3 = ops.get_collection( + "iterator_ops") + return init_op, (get_next_1, get_next_2, get_next_3) + + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _latest_ckpt(self): + return saver_lib.latest_checkpoint(self.get_temp_dir()) + + def _save(self, sess, saver): + saver.save(sess, self._ckpt_path()) + + def _restore(self, saver, sess): + saver.restore(sess, self._latest_ckpt()) + + def _import_meta_graph(self): + meta_file_path = self._ckpt_path() + ".meta" + return saver_lib.import_meta_graph(meta_file_path) + + def _testReadWithBreaks(self, break_points, init_before_restore=False): + expected = [] + actual = [] + # Generate the ground truth. + with ops.Graph().as_default() as g: + init_op, get_next_op, _ = self._build_graph() + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(self._num_outputs): + expected.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Run and checkpoint after first break_point. + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph() + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_points[0]): + actual.append(sess.run(get_next_op)) + self._save(sess, saver) + + # Load from checkpoint and continue running while stopping at each + # subsequent checkpoint. + for i in range(len(break_points)): + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection() + with self.test_session(graph=g) as sess: + if init_before_restore: + sess.run(init_op) + self._restore(saver, sess) + start = break_points[i] + end = break_points[ + i + 1] if i < len(break_points) - 1 else self._num_outputs + for _ in range(end - start): + actual.append(sess.run(get_next_op)) + self._save(sess, saver) + if end == self._num_outputs: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self._match(expected, actual) + + def _match(self, expected, actual): + self.assertEqual(len(expected), len(actual)) + for expected_tuple, actual_tuple in zip(expected, actual): + self.assertEqual(expected_tuple[0], actual_tuple[0]) + self.assertSequenceEqual(expected_tuple[1].tolist(), + actual_tuple[1].tolist()) + self.assertEqual(expected_tuple[2], actual_tuple[2]) + + def _does_not_match(self, expected, actual): + with self.assertRaises(AssertionError): + self._match(expected, actual) + + def testSaveRestore(self): + self._testReadWithBreaks([4]) + self._testReadWithBreaks([13]) + self._testReadWithBreaks([18]) + self._testReadWithBreaks([23]) + + def testSaveUnusedIterator(self): + self._testReadWithBreaks([0]) + + def testSaveFullyUsedIterator(self): + self._testReadWithBreaks([self._num_outputs]) + + def testMultipleBreaks(self): + self._testReadWithBreaks([0, 5, 9, 15, 25, 32]) + + def testIdempotence(self): + # Attempt to save iterator immediately after restoring. + self._testReadWithBreaks([1, 1, 5, 5, 5, 25, 32]) + + def testInitThenRestore(self): + self._testReadWithBreaks([0, 5, 9, 15, 25, 32], init_before_restore=True) + + def testRestoreExhaustedIterator(self): + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph() + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(self._num_outputs): + sess.run(get_next_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self._save(sess, saver) + + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection() + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testResetRestoredIterator(self): + expected = [] + # Collect ground truth containing all outputs. + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph() + break_point = self._num_outputs // 2 + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + expected.append(sess.run(get_next_op)) + self._save(sess, saver) + for _ in range(self._num_outputs - break_point): + expected.append(sess.run(get_next_op)) + + actual = [] + # Restore from checkpoint and then run init_op. + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection() + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + sess.run(init_op) + for _ in range(self._num_outputs): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self._match(expected, actual) + + def testRestoreInModifiedGraph(self): + expected = [] + actual_without_restore = [] + actual = [] + break_point = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(multiplier=15.0) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + expected.append(sess.run(get_next_op)) + actual.extend(expected) + self._save(sess, saver) + for _ in range(self._num_outputs - break_point): + expected.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Collect outputs by running modified graph. + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(multiplier=30.0) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(self._num_outputs): + actual_without_restore.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Restore the checkpoint in the modified graph. + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(multiplier=30.0) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(self._num_outputs - break_point): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Ensure the modified graph gets overridden when restoring checkpoint. + self._does_not_match(expected, actual_without_restore) + # Expect that the outputs are what we would expect if we ran the old + # graph. + self._match(expected, actual) + + # TODO(srbs): Add this test to dataset_serialization_test_base.py. + def testRestoreInEmptyGraph(self): + expected = [] + actual = [] + break_point = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(multiplier=15.0) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + self._save(sess, saver) + for _ in range(self._num_outputs - break_point): + expected.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + ds = self._build_ds() + output_types = ds.output_types + output_shapes = ds.output_shapes + + with ops.Graph().as_default() as g: + get_next_op, saver = self._build_empty_graph(output_types, output_shapes) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(self._num_outputs - break_point): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Expect that the outputs are what we would expect if we ran the old + # graph. + self._match(expected, actual) + + def testDoNotBuildSaveable(self): + break_point = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(multiplier=15.0) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + self._save(sess, saver) + + expected = [] + # Collect ground truth by running modified graph. + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(multiplier=30.0) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(self._num_outputs): + expected.append(sess.run(get_next_op)) + + actual = [] + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph( + multiplier=30.0, build_saveable=False) + with self.test_session(graph=g) as sess: + # Since the SaveableObject was not added to Saver's list + # of saveables, iterator state is not restored by saver.restore(). + self._restore(saver, sess) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(get_next_op) + sess.run(init_op) + for _ in range(self._num_outputs): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self._match(expected, actual) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(_map_fn) + + def _build_graph(): + ds = _build_ds() + iterator = ds.make_initializable_iterator() + + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + break_point = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + with self.assertRaises(errors.InvalidArgumentError): + self._save(sess, saver) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1))) + + def _build_graph(): + ds = _build_ds() + iterator = ds.make_initializable_iterator() + + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + break_point = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + with self.assertRaises(errors.InvalidArgumentError): + self._save(sess, saver) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + def _build_graph(): + ds = _build_ds() + iterator = ds.make_initializable_iterator() + + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + break_point = 10 + expected = [] + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + self._save(sess, saver) + for _ in range(num_outputs - break_point): + expected.append(sess.run(get_next_op)) + + with ops.Graph().as_default() as g: + ds = _build_ds() + output_types = ds.output_types + output_shapes = ds.output_shapes + + actual = [] + with ops.Graph().as_default() as g: + get_next_op, saver = self._build_empty_graph(output_types, output_shapes) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + self.assertSequenceEqual(expected, actual) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + def _build_graph(): + ds = _build_ds() + iterator = ds.make_initializable_iterator() + + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + break_point = 10 + expected = [] + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + self._save(sess, saver) + for _ in range(num_outputs - break_point): + expected.append(sess.run(get_next_op)) + + with ops.Graph().as_default() as g: + ds = _build_ds() + output_types = ds.output_types + output_shapes = ds.output_shapes + + actual = [] + with ops.Graph().as_default() as g: + get_next_op, saver = self._build_empty_graph(output_types, output_shapes) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + self.assertSequenceEqual(expected, actual) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 6783f7beb08..bf2e883bc53 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -50,21 +50,22 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/data/python/ops:prefetching_py", "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/eager:context", ], ) -py_test( +cuda_py_test( name = "datasets_test", srcs = ["datasets_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":datasets", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -240,6 +241,7 @@ py_test( "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:function", "//tensorflow/python/eager:test", ], ) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 98e6983658a..b559cce6b12 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -20,11 +20,15 @@ from __future__ import print_function import threading +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops @@ -32,12 +36,12 @@ _uid_counter = 0 _uid_lock = threading.Lock() -def _iterator_shared_name(): +def _generate_shared_name(prefix): with _uid_lock: global _uid_counter uid = _uid_counter _uid_counter += 1 - return "eager_iterator_{}".format(uid) + return "{}_{}".format(prefix, uid) class Iterator(object): @@ -72,11 +76,12 @@ class Iterator(object): with ops.device("/device:CPU:0"): ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access self._output_types = dataset.output_types + self._output_shapes = dataset.output_shapes self._flat_output_types = nest.flatten(dataset.output_types) self._flat_output_shapes = nest.flatten(dataset.output_shapes) self._resource = gen_dataset_ops.iterator( container="", - shared_name=_iterator_shared_name(), + shared_name=_generate_shared_name("eager_iterator"), output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) gen_dataset_ops.make_iterator(ds_variant, self._resource) @@ -84,6 +89,35 @@ class Iterator(object): self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="/device:CPU:0") self._device = context.context().device_name + self._buffer_resource_handle = None + if not context.context().device_spec.device_type: + is_remote_device = False + else: + is_remote_device = context.context().device_spec.device_type != "CPU" + if is_remote_device: + with ops.device("/device:CPU:0"): + iter_string_handle = gen_dataset_ops.iterator_to_string_handle( + self._resource) + + @function.Defun(dtypes.string) + def remote_fn(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, self._output_types, self._output_shapes) + return remote_iterator.get_next() + + remote_fn.add_to_graph(None) + target = constant_op.constant("/device:CPU:0") + with ops.device(self._device): + self._buffer_resource_handle = prefetching_ops.function_buffering_resource( + string_arg=iter_string_handle, + f=remote_fn, + target_device=target, + buffer_size=10, + thread_pool_size=1, + container="", + shared_name=_generate_shared_name("function_buffer_resource")) + self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._buffer_resource_handle, handle_device=self._device) def __iter__(self): return self @@ -93,20 +127,20 @@ class Iterator(object): def next(self): """Return the next tf.Tensor from the dataset.""" - try: - # TODO(ashankar): Consider removing this ops.device() contextmanager - # and instead mimic ops placement in graphs: Operations on resource - # handles execute on the same device as where the resource is placed. - with ops.device("/device:CPU:0"): - ret = gen_dataset_ops.iterator_get_next( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - except errors.OutOfRangeError: - raise StopIteration - # Copies tensors from CPU to the current device if necessary. - # TODO(rohanj): This should be replaced by the mechanism to have the - # runtime's threads copy tensors to the destination device. with ops.device(self._device): - ret = [array_ops.identity(x) for x in ret] + try: + if self._buffer_resource_handle is not None: + ret = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffer_resource_handle, + output_types=self._flat_output_types) + else: + # TODO(ashankar): Consider removing this ops.device() contextmanager + # and instead mimic ops placement in graphs: Operations on resource + # handles execute on the same device as where the resource is placed. + ret = gen_dataset_ops.iterator_get_next( + self._resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + except errors.OutOfRangeError: + raise StopIteration return nest.pack_sequence_as(self._output_types, ret) diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py index 02f82cb2169..7d2274db9b0 100644 --- a/tensorflow/contrib/eager/python/evaluator_test.py +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -87,7 +87,7 @@ class EvaluatorTest(test.TestCase): e.all_metric_results(logdir) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 6.0) @@ -136,7 +136,7 @@ class EvaluatorTest(test.TestCase): variables.global_variables_initializer().run() e.run_evaluation(init_op, call_op, results_op) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 6.0) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 736a75332ff..14c82c87a72 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -95,7 +95,7 @@ class ResNet50GraphTest(tf.test.TestCase): sess.run([train_op, tf.contrib.summary.all_summary_ops()], feed_dict={images: np_images, labels: np_labels}) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index d6389f2e385..582f4837c6f 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -103,7 +103,7 @@ class ResNet50Test(tf.test.TestCase): images, labels = random_batch(2) train_one_step(model, images, labels, optimizer) self.assertEqual(320, len(model.variables)) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index b4f5973bd11..96eb1b4f2a0 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -72,7 +72,7 @@ class MetricsTest(test.TestCase): name="t0").as_default(), summary_ops.always_record_summaries(): m.result() # As a side-effect will write summaries. - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 37.0) diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 1a5c6e8aec6..97eded7dca2 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -37,6 +37,406 @@ from tensorflow.python.training import training_util # functions in base.py which should be reused. +def _network_name_scope_naming(current_variable_scope): + """Name scope naming to match operation names to variable names. + + Used in Networks and also applied to non-Network Layers which are added to + Networks before being built. + + Args: + current_variable_scope: A VariableScope object. + Returns: + A name scope name. + """ + return current_variable_scope.name + "/" + + +class Network(base.Layer): + """Represents the composition of a set of Layers. + + TODO(josh11b,ashankar): + - Should "trainable" be changeable on the Network object? + - Do we allow add_variable in Network? + - Detect layers used in __call__ that weren't registered with track_layer. + - Convert inputs to __call__ to tensors. + - Prevent variables from being created after the first __call__? + (Think about restoring from a checkpoint). + """ + + def __init__(self, name=None): + if isinstance(name, variable_scope.VariableScope): + raise ValueError("VariableScopes are not valid Network names.") + if name is not None and "/" in name: + raise ValueError( + "Forward slashes ('/') are not allowed in Network names.") + super(Network, self).__init__(name=name) + self._layers = [] + self._sub_layer_name_uids = collections.defaultdict(int) + # Initially None, but set to False for networks which are first built as + # top-level. + self._first_parent = None # A weak reference to our first parent. + self._non_network_sublayers = [] + self._owned_layers = {} + # The scope to use if we end up without a parent. + self._default_parent_variable_scope = variable_scope.get_variable_scope() + # Hold on to the variable scope counts from init to check whether a scope + # with the name we want was ever created in our parent scope. Without this + # check we might have name collisions if the parent scope on init gets + # closed before build is called. + self._variable_scope_counts_on_init = ( + variable_scope._get_default_variable_store().variable_scopes_count) + + def _name_scope_name(self, current_variable_scope): + """Overrides Layer op naming to match variable naming.""" + return _network_name_scope_naming( + current_variable_scope=current_variable_scope) + + def _init_set_name(self, name): + # Anonymous Networks (name=None) defer setting a final name until they are + # (1) added to another Network, or (2) built/called (where (2) is only used + # for a "top level" network). + # + # However, if we were provided an explicit name (name is not None), that + # will always be the final name of the Network; if it turns out not to be + # unique or if variable names can't be prefixed by it we will throw an + # error. + self._name = name + self._base_name = None + + def _finalize_name(self, parent_network): + if not self._name: + # Were were not passed a name explicitly (or it was blank), so this is an + # anonymous Network. We make up a unique name. + if parent_network: + avoid_names = parent_network._owned_layers + name_uid_map = parent_network._sub_layer_name_uids + else: + name_uid_map = base._get_default_graph_uid_map() + # Figure out which names we have to avoid based on which variable scope + # we're nested in. + strip_name = self._default_parent_variable_scope.name + if strip_name: + strip_name += "/" + def _strip_on_init_scope(name): + if name.startswith(strip_name): + return name[len(strip_name):] + else: + return None + avoid_names = set( + _strip_on_init_scope(name) + for name in self._variable_scope_counts_on_init.keys() if name) + self._name, self._base_name = self._make_unique_name( + name_uid_map=name_uid_map, avoid_names=avoid_names, + namespace=self._default_parent_variable_scope.name, + zero_based=True) + if self._first_parent is None or (self._first_parent # False = no parent + and self._first_parent() is None): + # Save a pointer to the parent Network so that we can later check that the + # scope name we get is correct. + if not parent_network: + self._first_parent = parent_network + else: + self._first_parent = weakref.ref(parent_network) + + def _set_scope(self, scope=None): + if self._scope is None: + if not self._first_parent: + first_parent = self._first_parent + else: + first_parent = self._first_parent() + if first_parent is None: + # If we were never added to another Network, or that Network has beed + # garbage collected before being called, then we're a top-level Network. + self._finalize_name( + # Use False to make sure the value sticks and we don't inherit a + # parent if we're added to a network later. + parent_network=False) + if scope is not None: + raise ValueError("Networks may not be created with explicit scopes.") + if first_parent: + first_parent._set_scope() + parent_scope = first_parent._scope + else: + parent_scope = self._default_parent_variable_scope + with variable_scope.variable_scope(parent_scope) as parent_vs: + expected_scope_name = parent_vs.name + "/" + self._name + if expected_scope_name in self._variable_scope_counts_on_init: + raise ValueError( + ("A Network named '%s' already exists (or a variable_scope was " + "created with this name). Names must be unique.") % ( + self._name,)) + # Make sure variables with this prefix will be unique. + with variable_scope.variable_scope( + None, use_resource=True, default_name=self._name) as scope: + self._scope = scope + scope_name = scope.name + suffix_start = scope_name.rfind("/") + 1 + # rfind is -1 if there is no slash in the string, in which case the + # suffix starts at the beginning of the string (there is no prefix). + scope_suffix = scope_name[suffix_start:] + scope_prefix = scope_name[:suffix_start] + if scope_suffix != self._name: + raise ValueError( + ("A Network named '%s' already exists (or a variable_scope was " + "created with this name). Names must be unique.") % ( + self._name,)) + if (first_parent + and scope_prefix[:-1] != first_parent.scope_name): + raise ValueError( + ("Network variable names must match a nesting of sub-Network " + "names. Expected prefix '%s' from parent network, but got " + "'%s' when attempting to create a variable_scope for Network " + "'%s'. Likely an explicit variable_scope was inserted into " + "the nesting.") % ( + first_parent.scope_name, + scope_prefix[:-1], + self._name)) + elif not first_parent and scope_prefix: + # For the case when this Network is not nested inside any other + # Network, but is in a variable_scope. This Network's name takes on + # the full variable scope prefix. + self._name = scope_name + + for non_network_sublayer in self._non_network_sublayers: + self._set_scope_for_nonnetwork_sublayer(non_network_sublayer) + + def _set_scope_for_nonnetwork_sublayer(self, sublayer): + if sublayer._scope is None: + if sublayer._first_parent is None: + constituent_first_parent = None + else: + constituent_first_parent = sublayer._first_parent() + if constituent_first_parent: + constituent_first_parent._set_scope() + parent_scope = constituent_first_parent._scope + else: + self._finalize_name(False) + raise ValueError( + ("The parent of a Layer added to Network %s was garbage collected " + "before the Layer was built. If this limitation bothers you " + "please file a feature request.") % + (self.name,)) + with variable_scope.variable_scope(parent_scope): + # Horrid hack to make Layer variable names which are direct + # sub-layers of Networks conform to the Network variable naming + # conventions. + with variable_scope.variable_scope( + None, use_resource=True, + default_name=sublayer.name) as sub_scope: + sublayer._scope = sub_scope + # Also switch op naming for this Layer to match Network conventions, + # i.e. op naming matching variable naming. + sublayer._name_scope_name = _network_name_scope_naming + + @base.Layer.name.getter + def name(self): + if self._name is None: + raise ValueError( + "The network does not yet have a final name, but a name was " + "requested for it. Networks get a name when they are added to " + "another Network via track_layer, or when they are first " + "called/built.") + return self._name + + def track_layer(self, layer): + """Track a Layer in this Network. + + `Network` requires that all `Layer`s used in `call()` be tracked so that the + `Network` can export a complete list of variables. + + Args: + layer: A `tf.layers.Layer` object. + + Returns: + The passed in `layer`. + + Raises: + RuntimeError: If __init__ has not been called. + TypeError: If `layer` is the wrong type. + ValueError: If a `Layer` with the same name has already been added. + """ + if not hasattr(self, "_layers"): + raise RuntimeError("Need to call Network.__init__ before adding layers") + if not isinstance(layer, base.Layer): + raise TypeError( + "Network.track_layer() passed type %s, not a tf.layers.Layer" % + (type(layer),)) + if isinstance(layer, Network): + layer._finalize_name(parent_network=self) + else: + # `layer` is a non-Network, so it hasn't been named to follow Network + # conventions for contained Layers (i.e. the same conventions as for + # sub-Networks). This renaming is necessary to isolate Network variable + # naming from Layers constructed outside the Network and never added to it + # (because Layers are named globally). + if not layer.built: + if not hasattr(layer, "_first_parent"): + dereferenced_layer_first_parent = None + else: + dereferenced_layer_first_parent = layer._first_parent() + if dereferenced_layer_first_parent is None: + if layer._name != layer._base_name: + # If name and base_name do not match, then this Layer used anonymous + # naming and we have to rename it. Otherwise there's an explicit + # name, and we should respect it (subject to error checking). + layer._name, layer._base_name = layer._make_unique_name( + name_uid_map=self._sub_layer_name_uids, + avoid_names=self._owned_layers, + zero_based=True + # No namespace required, since we've specified our own UID map. + ) + layer._first_parent = weakref.ref(self) + self._non_network_sublayers.append(layer) + if (not layer.built + and layer._first_parent + and self is layer._first_parent()): + if layer.name in self._owned_layers: + if self._owned_layers[layer.name] is layer: + return layer + raise ValueError( + "Attempt to add two Layers with the name '%s' to the same Network." + % (layer.name)) + self._owned_layers[layer.name] = layer + self._layers.append(layer) + return layer + + def get_layer(self, name=None, index=None): + """Get a contained `tf.layers.Layer` either by name or index. + + Args: + name: String matching one of the names of a contained `Layer`. Note that + the names of `Layer`s added to `Network`s may not be unique when doing + layer sharing (i.e. adding a `Layer` to this `Network` which was already + added to another `Network`). The lowest index `Layer` with a matching + name will be returned. + index: Integer in [0, number of layers). Layers are assigned an index + by the order they are added. + + Returns: + A `tf.layers.Layer` object. + + Raises: + ValueError: If neither or both of 'index' or 'name' is specified, or the + lookup failed. + """ + if index is not None: + if name is not None: + raise ValueError("Exactly one of 'index' or 'name' must be provided") + if len(self._layers) <= index: + raise ValueError("Was asked to retrieve layer at index " + str(index) + + " but model only has " + str(len(self._layers)) + + " layers.") + else: + return self._layers[index] + else: + if not name: + raise ValueError("Provide either a layer name or layer index.") + for layer in self._layers: + if layer.name == name: + return layer + raise ValueError("No such layer: " + name) + + # The following methods are for implementing the Layer interface. + + @property + def weights(self): + # TODO(josh11b): Should this return a set or perform de-duplication of + # variables in the case of shared layers/variables that appear in + # multiple places in the Network? + weights = [] + for layer in self._layers: + weights += layer.weights + return weights + + @property + def trainable_weights(self): + weights = [] + for layer in self._layers: + weights += layer.trainable_weights + return weights + + @property + def non_trainable_weights(self): + weights = [] + for layer in self._layers: + weights += layer.non_trainable_weights + return weights + + @property + def trainable(self): + return True + + @trainable.setter + def trainable(self, value): + if not value: + # We believe it better to decide which layers & networks are trainable + # at the Trainer level than here. Otherwise you can run into trouble if a + # layer/network is shared between two models, but is trainable in one + # but not the other (like with adversarial networks). + raise AttributeError("cannot mark Network as not trainable") + + @property + def layers(self): + return self._layers + + def add_variable(self, name, shape, dtype=None, initializer=None, + regularizer=None, trainable=True, constraint=None): + raise RuntimeError( + "add_variable not supported in Network class yet. Please file an issue " + "at https://github.com/tensorflow/tensorflow/issues/new if this is " + "important to you") + + # TODO(josh11b): Support other Layer methods needed for graph mode, such as for + # losses and updates + + +class Sequential(Network): + """Represents a linear sequence of Layers or functions. + + The output of each layer/function is provided as the input to the next. + The inputs passed to `__call__` are passed to the inputs of the first + Layer, and it returns the outputs of the last Layer. + + Args: + layers_funcs: An optional sequence where each element is either a + tf.layers.Layer object or a callable. + name: An optional string name to use for this Network. + """ + + def __init__(self, layers_funcs=None, name=None): + super(Sequential, self).__init__(name=name) + self._layers_funcs = [] + if layers_funcs: + for l in layers_funcs: + self.add(l) + + def add(self, layer_func): + if isinstance(layer_func, base.Layer): + args = estimator_util.fn_args(layer_func.call) + self.track_layer(layer_func) + elif callable(layer_func): + args = estimator_util.fn_args(layer_func) + else: + raise TypeError( + "Sequential.add() takes only tf.layers.Layer objects or callables; " + "not '%s' of type '%s'." % (layer_func, type(layer_func))) + self._layers_funcs.append((("training" in args), layer_func)) + + def call(self, inputs, training=None): + """Call each Layer in the order they were added.""" + # TODO(josh11b): Support "mode" and maybe other arguments + if training is None: + for _, l in self._layers_funcs: + inputs = l(inputs) + else: + for has_training_arg, l in self._layers_funcs: + if has_training_arg: + inputs = l(inputs, training) + else: + inputs = l(inputs) + return inputs + + _DeferredRestoration = collections.namedtuple( "_DeferredRestoration", @@ -73,10 +473,10 @@ def _default_naming_conflict_error_message( "happen when using variable sharing (i.e. the Network contains Networks " "or Layers which were first added to another Network, and therefore " "have that Network's variable prefix). One solution is to pass " - "`map_func=lambda n: n` to Network.save and Network.restore to use " - "fully qualified variable names in the checkpoint, although this will " - "require that the variable prefix of the Network being restored into " - "is also '%s'. You may alternatively write an arbitrary mapping.") + "`map_func=lambda n: n` to save and restore to use fully qualified " + "variable names in the checkpoint, although this will require that the " + "variable prefix of the Network being restored into is also '%s'. You " + "may alternatively write an arbitrary mapping.") % ( network_name, network_scope_name, mapped_name, first_variable._shared_name, @@ -88,9 +488,9 @@ def _restore_custom_map_func_error_message( mapped_name, first_variable, second_variable, network_name, network_scope_name): return ( - ("The map_func passed to Network.restore for the Network '%s' " + ("The map_func passed to restore_network_checkpoint for the Network '%s' " "resulted in two variables named '%s' (originally '%s' and '%s'). Since " - "this is also an error on Network.save, this Network was " + "this is also an error when saving, this Network was " "probably not saved with this map_func. Note that map_func " "always maps from full variable names to checkpoint names; " "there is no need to specify an inverse mapping.\n\n" @@ -216,625 +616,255 @@ def _make_prefix_stripping_map_fn(scope_name): return _strip_variable_prefix -class Network(base.Layer): - """Represents the composition of a set of Layers. +def save_network_checkpoint( + network, save_path, global_step=None, map_func=None): + """Save variables from the Network to a checkpoint. - TODO(josh11b,ashankar): - - Should "trainable" be changeable on the Network object? - - Do we allow add_variable in Network? - - Detect layers used in __call__ that weren't registered with track_layer. - - Convert inputs to __call__ to tensors. - - Prevent variables from being created after the first __call__? - (Think about restoring from a checkpoint). + Args: + network: A Network object to save. + save_path: Either a checkpoint prefix or the name of a directory to save + the checkpoint in (in which case the checkpoint will be named based on + the Network name). + global_step: The global step to use when naming the checkpoint. If None + (default), we will first try to get the default global step. If that + fails because no default global step exists, then the checkpoint is + created without a global step suffix. + map_func: A function mapping fully qualified variable names + (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By + default (if `map_func=None`), the variable prefix for the network being + restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped + and all other variable names (shared with other Networks) are left + unchanged. + Returns: + The checkpoint prefix for the saved checkpoint, which may be passed to + `Network.restore`. + Raises: + ValueError: If the Network has not yet been called, or if map_func results + in a name collision. """ + if not network.built: + raise ValueError( + "Attempt to save the Network before it was first called. This means " + "variables have not yet been created, so there is nothing to save.") + network._set_scope() # scope_name should be available to map_funcs + if global_step is None: + global_step = training_util.get_global_step() + if os.path.isdir(save_path): + # If we were passed a directory, default to naming based on the Network + # name. + save_path = os.path.join(save_path, network.name.replace("/", "_")) + user_map_func = map_func + if map_func is None: + map_func = _make_prefix_stripping_map_fn(network.scope_name) + variable_map = {} + for variable in network.variables: + mapped_name = map_func(variable._shared_name) + if variable_map.setdefault(mapped_name, variable) is not variable: + if user_map_func is None: + # Instead of erroring out, we could just re-try and silently use the + # full variable names in the checkpoint. This could be odd for deeply + # nested sub-Networks (since the full prefix from the nesting would + # get added), so for now we'll let the user deal with this case. + raise ValueError(_default_naming_conflict_error_message( + mapped_name=mapped_name, + first_variable=variable_map[mapped_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) + else: + # The user passed their own problematic map_func. + raise ValueError( + ("The map_func passed to save_network_checkpoint for the Network " + "'%s' resulted in two variables named '%s' ('%s' and '%s'). Try " + "stripping less from the variable names, or renaming parts of " + "the Network. For reference, variables created by sub-Layers of " + "this Network are prefixed with '%s', but if they are re-used " + "after being added to another Network, they will have that " + "Network's full variable prefix instead.") % ( + network.name, mapped_name, + variable_map[mapped_name]._shared_name, + variable._shared_name, + network.scope_name)) + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + return saver_lib.Saver(variable_map).save( + sess=sess, save_path=save_path, write_meta_graph=False, + global_step=global_step) - def __init__(self, name=None): - if isinstance(name, variable_scope.VariableScope): - raise ValueError("VariableScopes are not valid Network names.") - if name is not None and "/" in name: - raise ValueError( - "Forward slashes ('/') are not allowed in Network names.") - super(Network, self).__init__(name=name) - self._layers = [] - self._sub_layer_name_uids = collections.defaultdict(int) - # Initially None, but set to False for networks which are first built as - # top-level. - self._first_parent = None # A weak reference to our first parent. - self._non_network_sublayers = [] - self._owned_layers = {} - # The scope to use if we end up without a parent. - self._default_parent_variable_scope = variable_scope.get_variable_scope() - # Hold on to the variable scope counts from init to check whether a scope - # with the name we want was ever created in our parent scope. Without this - # check we might have name collisions if the parent scope on init gets - # closed before build is called. - self._variable_scope_counts_on_init = ( - variable_scope._get_default_variable_store().variable_scopes_count) - self._custom_getter, self._deferred_restorations = ( + +def _add_deferred_restoration(layer, deferred_restoration): + """Add a deferred restoration to this Layer and all children. + + Restorations which are requested later have higher priority, and the highest + priority matching restoration is applied to a variable when it is created. + + Args: + layer: The Layer (may not be a Network) to operate on. + deferred_restoration: A _DeferredRestoration object. + """ + # Networks don't create variables at the moment, so this append isn't strictly + # necessary. We could get by with only adding deferred restorations to + # non-Network Layers. + if isinstance(layer, Network): + layer._set_scope() + # Make sure this Layer has a deferred restoration queue and a custom getter, + # then add our request to it. + if not hasattr(layer, "_custom_getter"): + assert not hasattr(layer, "_deferred_restorations") + layer._custom_getter, layer._deferred_restorations = ( _make_custom_getter_for_deferred_restorations()) + # We use set_custom_getter because it avoids recursively calling up the + # variable_scope tree. We've done the tree traversal ourselves and have added + # the request to each Layer which needs it. + layer._scope.set_custom_getter(layer._custom_getter) + layer._deferred_restorations.append(deferred_restoration) + if isinstance(layer, Network): + for sublayer in layer.layers: + if not isinstance(sublayer, Network): + layer._set_scope_for_nonnetwork_sublayer(sublayer) + _add_deferred_restoration(sublayer, deferred_restoration) - def _init_set_name(self, name): - # Anonymous Networks (name=None) defer setting a final name until they are - # (1) added to another Network, or (2) built/called (where (2) is only used - # for a "top level" network). - # - # However, if we were provided an explicit name (name is not None), that - # will always be the final name of the Network; if it turns out not to be - # unique or if variable names can't be prefixed by it we will throw an - # error. - self._name = name - self._base_name = None - def _finalize_name(self, parent_network): - if not self._name: - # Were were not passed a name explicitly (or it was blank), so this is an - # anonymous Network. We make up a unique name. - if parent_network: - avoid_names = parent_network._owned_layers - name_uid_map = parent_network._sub_layer_name_uids +def _restore_existing_variables(network, save_path, map_func, user_map_func): + """Use a standard Saver to restore existing variables from a checkpoint. + + Args: + network: A Network object to restore. + save_path: The checkpoint prefix or directory to read from. + map_func: The function to use when mapping from variable names to + checkpoint names. + user_map_func: The original map_func passed by the user, for error + checking. + Returns: + A dictionary mapping from checkpoint names to variable objects which have + been restored (for bookkeeping to avoid deferred restorations on these + variables). + Raises: + ValueError: If there is a name collision. + """ + existing_variables_by_checkpoint_name = {} + for variable in network.variables: + checkpoint_name = map_func(variable._shared_name) + if existing_variables_by_checkpoint_name.setdefault( + checkpoint_name, variable) is not variable: + if user_map_func is None: + raise ValueError(_default_naming_conflict_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) else: - name_uid_map = base._get_default_graph_uid_map() - # Figure out which names we have to avoid based on which variable scope - # we're nested in. - strip_name = self._default_parent_variable_scope.name - if strip_name: - strip_name += "/" - def _strip_on_init_scope(name): - if name.startswith(strip_name): - return name[len(strip_name):] - else: - return None - avoid_names = set( - _strip_on_init_scope(name) - for name in self._variable_scope_counts_on_init.keys() if name) - self._name, self._base_name = self._make_unique_name( - name_uid_map=name_uid_map, avoid_names=avoid_names, - namespace=self._default_parent_variable_scope.name) - if self._first_parent is None or (self._first_parent # False = no parent - and self._first_parent() is None): - # Save a pointer to the parent Network so that we can later check that the - # scope name we get is correct. - if not parent_network: - self._first_parent = parent_network - else: - self._first_parent = weakref.ref(parent_network) - - def _set_scope(self, scope=None): - if self._scope is None: - if not self._first_parent: - first_parent = self._first_parent - else: - first_parent = self._first_parent() - if first_parent is None: - # If we were never added to another Network, or that Network has beed - # garbage collected before being called, then we're a top-level Network. - self._finalize_name( - # Use False to make sure the value sticks and we don't inherit a - # parent if we're added to a network later. - parent_network=False) - if scope is not None: - raise ValueError("Networks may not be created with explicit scopes.") - if first_parent: - first_parent._set_scope() - parent_scope = first_parent._scope - else: - parent_scope = self._default_parent_variable_scope - with variable_scope.variable_scope(parent_scope) as parent_vs: - expected_scope_name = parent_vs.name + "/" + self._name - if expected_scope_name in self._variable_scope_counts_on_init: - raise ValueError( - ("A Network named '%s' already exists (or a variable_scope was " - "created with this name). Names must be unique.") % ( - self._name,)) - # Make sure variables with this prefix will be unique. - with variable_scope.variable_scope( - None, use_resource=True, default_name=self._name) as scope: - self._scope = scope - scope_name = scope.name - suffix_start = scope_name.rfind("/") + 1 - # rfind is -1 if there is no slash in the string, in which case the - # suffix starts at the beginning of the string (there is no prefix). - scope_suffix = scope_name[suffix_start:] - scope_prefix = scope_name[:suffix_start] - if scope_suffix != self._name: - raise ValueError( - ("A Network named '%s' already exists (or a variable_scope was " - "created with this name). Names must be unique.") % ( - self._name,)) - if (first_parent - and scope_prefix[:-1] != first_parent.scope_name): - raise ValueError( - ("Network variable names must match a nesting of sub-Network " - "names. Expected prefix '%s' from parent network, but got " - "'%s' when attempting to create a variable_scope for Network " - "'%s'. Likely an explicit variable_scope was inserted into " - "the nesting.") % ( - first_parent.scope_name, - scope_prefix[:-1], - self._name)) - elif not first_parent and scope_prefix: - # For the case when this Network is not nested inside any other - # Network, but is in a variable_scope. This Network's name takes on - # the full variable scope prefix. - self._name = scope_name - - for non_network_sublayer in self._non_network_sublayers: - self._set_scope_for_nonnetwork_sublayer(non_network_sublayer) - - def _set_scope_for_nonnetwork_sublayer(self, sublayer): - if sublayer._scope is None: - if sublayer._first_parent is None: - constituent_first_parent = None - else: - constituent_first_parent = sublayer._first_parent() - if constituent_first_parent: - constituent_first_parent._set_scope() - parent_scope = constituent_first_parent._scope - else: - self._finalize_name(False) - raise ValueError( - ("The parent of a Layer added to Network %s was garbage collected " - "before the Layer was built. If this limitation bothers you " - "please file a feature request.") % - (self.name,)) - with variable_scope.variable_scope(parent_scope): - # Horrid hack to make Layer variable names which are direct - # sub-layers of Networks conform to the Network variable naming - # conventions. - with variable_scope.variable_scope( - None, use_resource=True, - default_name=sublayer.name) as sub_scope: - sublayer._scope = sub_scope - - @base.Layer.name.getter - def name(self): - if self._name is None: - raise ValueError( - "The network does not yet have a final name, but a name was " - "requested for it. Networks get a name when they are added to " - "another Network via track_layer, or when they are first " - "called/built.") - return self._name - - def track_layer(self, layer): - """Track a Layer in this Network. - - `Network` requires that all `Layer`s used in `call()` be tracked so that the - `Network` can export a complete list of variables. - - Args: - layer: A `tf.layers.Layer` object. - - Returns: - The passed in `layer`. - - Raises: - RuntimeError: If __init__ has not been called. - TypeError: If `layer` is the wrong type. - ValueError: If a `Layer` with the same name has already been added. - """ - if not hasattr(self, "_layers"): - raise RuntimeError("Need to call Network.__init__ before adding layers") - if not isinstance(layer, base.Layer): - raise TypeError( - "Network.track_layer() passed type %s, not a tf.layers.Layer" % - (type(layer),)) - if isinstance(layer, Network): - layer._finalize_name(parent_network=self) - else: - # `layer` is a non-Network, so it hasn't been named to follow Network - # conventions for contained Layers (i.e. the same conventions as for - # sub-Networks). This renaming is necessary to isolate Network variable - # naming from Layers constructed outside the Network and never added to it - # (because Layers are named globally). - if not layer.built: - if not hasattr(layer, "_first_parent"): - dereferenced_layer_first_parent = None - else: - dereferenced_layer_first_parent = layer._first_parent() - if dereferenced_layer_first_parent is None: - if layer._name != layer._base_name: - # If name and base_name do not match, then this Layer used anonymous - # naming and we have to rename it. Otherwise there's an explicit - # name, and we should respect it (subject to error checking). - layer._name, layer._base_name = layer._make_unique_name( - name_uid_map=self._sub_layer_name_uids, - avoid_names=self._owned_layers - # No namespace required, since we've specified our own UID map. - ) - layer._first_parent = weakref.ref(self) - self._non_network_sublayers.append(layer) - if (not layer.built - and layer._first_parent - and self is layer._first_parent()): - if layer.name in self._owned_layers: - if self._owned_layers[layer.name] is layer: - return layer - raise ValueError( - "Attempt to add two Layers with the name '%s' to the same Network." - % (layer.name)) - self._owned_layers[layer.name] = layer - self._layers.append(layer) - return layer - - def get_layer(self, name=None, index=None): - """Get a contained `tf.layers.Layer` either by name or index. - - Args: - name: String matching one of the names of a contained `Layer`. Note that - the names of `Layer`s added to `Network`s may not be unique when doing - layer sharing (i.e. adding a `Layer` to this `Network` which was already - added to another `Network`). The lowest index `Layer` with a matching - name will be returned. - index: Integer in [0, number of layers). Layers are assigned an index - by the order they are added. - - Returns: - A `tf.layers.Layer` object. - - Raises: - ValueError: If neither or both of 'index' or 'name' is specified, or the - lookup failed. - """ - if index is not None: - if name is not None: - raise ValueError("Exactly one of 'index' or 'name' must be provided") - if len(self._layers) <= index: - raise ValueError("Was asked to retrieve layer at index " + str(index) + - " but model only has " + str(len(self._layers)) + - " layers.") - else: - return self._layers[index] - else: - if not name: - raise ValueError("Provide either a layer name or layer index.") - for layer in self._layers: - if layer.name == name: - return layer - raise ValueError("No such layer: " + name) - - # The following methods are for implementing the Layer interface. - - @property - def weights(self): - # TODO(josh11b): Should this return a set or perform de-duplication of - # variables in the case of shared layers/variables that appear in - # multiple places in the Network? - weights = [] - for layer in self._layers: - weights += layer.weights - return weights - - @property - def trainable_weights(self): - weights = [] - for layer in self._layers: - weights += layer.trainable_weights - return weights - - @property - def non_trainable_weights(self): - weights = [] - for layer in self._layers: - weights += layer.non_trainable_weights - return weights - - @property - def trainable(self): - return True - - @trainable.setter - def trainable(self, value): - if not value: - # We believe it better to decide which layers & networks are trainable - # at the Trainer level than here. Otherwise you can run into trouble if a - # layer/network is shared between two models, but is trainable in one - # but not the other (like with adversarial networks). - raise AttributeError("cannot mark Network as not trainable") - - @property - def layers(self): - return self._layers - - def add_variable(self, name, shape, dtype=None, initializer=None, - regularizer=None, trainable=True, constraint=None): - raise RuntimeError( - "add_variable not supported in Network class yet. Please file an issue " - "at https://github.com/tensorflow/tensorflow/issues/new if this is " - "important to you") - - def save(self, save_path, global_step=None, map_func=None): - """Save variables from the Network to a checkpoint. - - Args: - save_path: Either a checkpoint prefix or the name of a directory to save - the checkpoint in (in which case the checkpoint will be named based on - the Network name). - global_step: The global step to use when naming the checkpoint. If None - (default), we will first try to get the default global step. If that - fails because no default global step exists, then the checkpoint is - created without a global step suffix. - map_func: A function mapping fully qualified variable names - (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By - default (if `map_func=None`), the variable prefix for the network being - restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped - and all other variable names (shared with other Networks) are left - unchanged. - Returns: - The checkpoint prefix for the saved checkpoint, which may be passed to - `Network.restore`. - Raises: - ValueError: If the Network has not yet been called, or if map_func results - in a name collision. - """ - if not self.built: - raise ValueError( - "Attempt to save the Network before it was first called. This means " - "variables have not yet been created, so there is nothing to save.") - self._set_scope() # scope_name should be available to map_funcs - if global_step is None: - global_step = training_util.get_global_step() - if os.path.isdir(save_path): - # If we were passed a directory, default to naming based on the Network - # name. - save_path = os.path.join(save_path, self.name.replace("/", "_")) - user_map_func = map_func - if map_func is None: - map_func = _make_prefix_stripping_map_fn(self.scope_name) - variable_map = {} - for variable in self.variables: - mapped_name = map_func(variable._shared_name) - if variable_map.setdefault(mapped_name, variable) is not variable: - if user_map_func is None: - # Instead of erroring out, we could just re-try and silently use the - # full variable names in the checkpoint. This could be odd for deeply - # nested sub-Networks (since the full prefix from the nesting would - # get added), so for now we'll let the user deal with this case. - raise ValueError(_default_naming_conflict_error_message( - mapped_name=mapped_name, - first_variable=variable_map[mapped_name], - second_variable=variable, - network_name=self.name, - network_scope_name=self.scope_name)) - else: - # The user passed their own problematic map_func. - raise ValueError( - ("The map_func passed to Network.save for the Network '%s' " - "resulted in two variables named '%s' ('%s' and '%s'). Try " - "stripping less from the variable names, or renaming parts of " - "the Network. For reference, variables created by sub-Layers of " - "this Network are prefixed with '%s', but if they are re-used " - "after being added to another Network, they will have that " - "Network's full variable prefix instead.") % ( - self.name, mapped_name, - variable_map[mapped_name]._shared_name, - variable._shared_name, - self.scope_name)) + raise ValueError(_restore_custom_map_func_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) + if existing_variables_by_checkpoint_name: if context.in_eager_mode(): sess = None else: sess = ops.get_default_session() - return saver_lib.Saver(variable_map).save( - sess=sess, save_path=save_path, write_meta_graph=False, - global_step=global_step) + saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore( + sess=sess, save_path=save_path) + return existing_variables_by_checkpoint_name - def _restore_existing_variables(self, save_path, map_func, user_map_func): - """Use a standard Saver to restore existing variables from a checkpoint. - Args: - save_path: The checkpoint prefix or directory to read from. - map_func: The function to use when mapping from variable names to - checkpoint names. - user_map_func: The original map_func passed by the user, for error - checking. - Returns: - A dictionary mapping from checkpoint names to variable objects which have - been restored (for bookkeeping to avoid deferred restorations on these - variables). - Raises: - ValueError: If there is a name collision. - """ - existing_variables_by_checkpoint_name = {} - for variable in self.variables: - checkpoint_name = map_func(variable._shared_name) - if existing_variables_by_checkpoint_name.setdefault( - checkpoint_name, variable) is not variable: - if user_map_func is None: - raise ValueError(_default_naming_conflict_error_message( - mapped_name=checkpoint_name, - first_variable=existing_variables_by_checkpoint_name[ - checkpoint_name], - second_variable=variable, - network_name=self.name, - network_scope_name=self.scope_name)) - else: - raise ValueError(_restore_custom_map_func_error_message( - mapped_name=checkpoint_name, - first_variable=existing_variables_by_checkpoint_name[ - checkpoint_name], - second_variable=variable, - network_name=self.name, - network_scope_name=self.scope_name)) - if existing_variables_by_checkpoint_name: - if context.in_eager_mode(): - sess = None - else: - sess = ops.get_default_session() - saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore( - sess=sess, save_path=save_path) - return existing_variables_by_checkpoint_name - - def _set_restore_on_create(self, save_path, map_func, user_map_func, - existing_variables_by_checkpoint_name): - """If necessary, request deferred restorations of variables.""" - checkpoint_reader = checkpoint_utils.load_checkpoint(save_path) - checkpointed_variables_to_restore = {} - for checkpoint_name, _ in checkpoint_utils.list_variables(save_path): - if checkpoint_name in existing_variables_by_checkpoint_name: - # This variable was already created and restored. - continue - # Save the variable for later restoration in a custom getter. - checkpointed_variables_to_restore[checkpoint_name] = ( - checkpoint_reader.get_tensor(checkpoint_name)) - # Only set a deferred restoration if there are checkpoint variables which - # have not been assigned to existing variables. Note that this loses out on - # some opportunity for error checking, but avoids creating - # _DeferredRestoration objects once a Network has been built (so that - # restoring in a loop does not take increasing amounts of memory). - if checkpointed_variables_to_restore: - if context.in_eager_mode(): - sess = None - else: - sess = ops.get_default_session() - # We need a name for error messages. If we haven't been added to another - # Network yet, we're top-level. - self._finalize_name(False) - self._set_scope() - # Save a record of this restoration for use in the custom getter. - deferred_restoration = _DeferredRestoration( - map_func=map_func, - map_func_is_user=(user_map_func is not None), - checkpointed_variables_to_restore=checkpointed_variables_to_restore, - restored_variables={}, - session=sess, - network_name=self.name, - network_scope_name=self.scope_name) - self._deferred_restorations.append(deferred_restoration) - # Add the deferred registration to non-Network children, and request that - # Networks propagate the request to their children. - self._add_deferred_restoration(deferred_restoration) - - def _add_deferred_restoration(self, deferred_restoration): - """Add a deferred restoration to this Network and all children. - - Restorations which are requested later have higher priority, and the highest - priority matching restoration is applied to a variable when it is created. - - Args: - deferred_restoration: A _DeferredRestoration object. - """ - # Networks don't create variables at the moment, so this append isn't - # strictly necessary. We could get by with only adding deferred restorations - # to non-Network Layers. - self._set_scope() - # We use set_custom_getter because it avoids recursively calling up the - # variable_scope tree. We've done the tree traversal ourselves and have - # added the request to each Layer which needs it. - self._scope.set_custom_getter(self._custom_getter) - self._deferred_restorations.append(deferred_restoration) - for layer in self.layers: - if isinstance(layer, Network): - # For Networks, request that they propagate this deferred restoration - # to all of their children recursively. - layer._add_deferred_restoration(deferred_restoration) - else: - # For non-Network Layers, make sure they have a deferred restoration - # queue and a custom getter, then add our request to it. - if not hasattr(layer, "_custom_getter"): - assert not hasattr(layer, "_deferred_restorations") - layer._custom_getter, layer._deferred_restorations = ( - _make_custom_getter_for_deferred_restorations()) - self._set_scope_for_nonnetwork_sublayer(layer) - layer._scope.set_custom_getter(layer._custom_getter) - layer._deferred_restorations.append(deferred_restoration) - - def restore(self, save_path, map_func=None): - """Restore the Network from a checkpoint. - - If variables have already been created (typically when some or all of the - `Network` is built), they are assigned values from the checkpoint - immediately, overwriting any existing values (in graph mode the default - session is used for the assignments). - - If there are checkpoint entries which do not correspond to any existing - variables in the `Network`, these values are saved for deferred restoration; - their initial values will be the checkpointed values once they are - created. Requests for multiple deferred restorations behave the same way as - immediate restorations, in that later requests will take priority over - earlier requests relevant to the same variable. - - If this `Network` shares `Layer`s with another network, those `Layer`s will - also have their variables restored from the checkpoint. - - Args: - save_path: The return value of `Network.save`, or a directory to search - for a checkpoint. - map_func: A function mapping fully qualified variable names - (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By - default (if `map_func=None`), the variable prefix for the network being - restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped - and all other variable names (shared with other Networks) are left - unchanged. Note that this is the _same_ map_func as `Network.save`, not - an inverse mapping. - """ - self._finalize_name(parent_network=False) - self._set_scope() # scope_name should be available to map_funcs - if os.path.isdir(save_path): - # If we don't have a name yet, set no parent. - save_path = os.path.join(save_path, self.name.replace("/", "_")) - user_map_func = map_func - if map_func is None: - map_func = _make_prefix_stripping_map_fn(self.scope_name) - # Step one is to restore any existing variables from the checkpoint. - existing_variables_by_checkpoint_name = self._restore_existing_variables( - save_path=save_path, +def _set_restore_on_create(network, save_path, map_func, user_map_func, + existing_variables_by_checkpoint_name): + """If necessary, request deferred restorations of variables.""" + checkpoint_reader = checkpoint_utils.load_checkpoint(save_path) + checkpointed_variables_to_restore = {} + for checkpoint_name, _ in checkpoint_utils.list_variables(save_path): + if checkpoint_name in existing_variables_by_checkpoint_name: + # This variable was already created and restored. + continue + # Save the variable for later restoration in a custom getter. + checkpointed_variables_to_restore[checkpoint_name] = ( + checkpoint_reader.get_tensor(checkpoint_name)) + # Only set a deferred restoration if there are checkpoint variables which + # have not been assigned to existing variables. Note that this loses out on + # some opportunity for error checking, but avoids creating + # _DeferredRestoration objects once a Network has been built (so that + # restoring in a loop does not take increasing amounts of memory). + if checkpointed_variables_to_restore: + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + # We need a name for error messages. If we haven't been added to another + # Network yet, we're top-level. + network._finalize_name(False) + network._set_scope() + # Save a record of this restoration for use in the custom getter. + deferred_restoration = _DeferredRestoration( map_func=map_func, - user_map_func=user_map_func) - # Step two is to set a custom getter which restores variables on creation, - # for those variables which have not been added to sub-Layers yet. - self._set_restore_on_create( - save_path=save_path, - map_func=map_func, - user_map_func=user_map_func, - existing_variables_by_checkpoint_name=( - existing_variables_by_checkpoint_name)) - - # TODO(josh11b): Support other Layer methods needed for graph mode, such as for - # losses and updates + map_func_is_user=(user_map_func is not None), + checkpointed_variables_to_restore=checkpointed_variables_to_restore, + restored_variables={}, + session=sess, + network_name=network.name, + network_scope_name=network.scope_name) + # Add the deferred registration to non-Network children, and request that + # Networks propagate the request to their children. + _add_deferred_restoration(network, deferred_restoration) -class Sequential(Network): - """Represents a linear sequence of Layers or functions. +def restore_network_checkpoint(network, save_path, map_func=None): + """Restore the Network from a checkpoint. - The output of each layer/function is provided as the input to the next. - The inputs passed to `__call__` are passed to the inputs of the first - Layer, and it returns the outputs of the last Layer. + If variables have already been created (typically when some or all of the + `Network` is built), they are assigned values from the checkpoint immediately, + overwriting any existing values (in graph mode the default session is used for + the assignments). + + If there are checkpoint entries which do not correspond to any existing + variables in the `Network`, these values are saved for deferred restoration; + their initial values will be the checkpointed values once they are + created. Requests for multiple deferred restorations behave the same way as + immediate restorations, in that later requests will take priority over earlier + requests relevant to the same variable. + + If this `Network` shares `Layer`s with another network, those `Layer`s will + also have their variables restored from the checkpoint. Args: - layers_funcs: An optional sequence where each element is either a - tf.layers.Layer object or a callable. - name: An optional string name to use for this Network. + network: A Network object to restore. + save_path: The return value of `tfe.save_network_checkpoint`, or a directory + to search for a checkpoint. + map_func: A function mapping fully qualified variable names + (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By + default (if `map_func=None`), the variable prefix for the network being + restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped + and all other variable names (shared with other Networks) are left + unchanged. Note that this is the _same_ map_func as + `tfe.save_network_checkpoint`, not an inverse mapping. """ - - def __init__(self, layers_funcs=None, name=None): - super(Sequential, self).__init__(name=name) - self._layers_funcs = [] - if layers_funcs: - for l in layers_funcs: - self.add(l) - - def add(self, layer_func): - if isinstance(layer_func, base.Layer): - args = estimator_util.fn_args(layer_func.call) - self.track_layer(layer_func) - elif callable(layer_func): - args = estimator_util.fn_args(layer_func) - else: - raise TypeError( - "Sequential.add() takes only tf.layers.Layer objects or callables; " - "not '%s' of type '%s'." % (layer_func, type(layer_func))) - self._layers_funcs.append((("training" in args), layer_func)) - - def call(self, inputs, training=None): - """Call each Layer in the order they were added.""" - # TODO(josh11b): Support "mode" and maybe other arguments - if training is None: - for _, l in self._layers_funcs: - inputs = l(inputs) - else: - for has_training_arg, l in self._layers_funcs: - if has_training_arg: - inputs = l(inputs, training) - else: - inputs = l(inputs) - return inputs + network._finalize_name(parent_network=False) + network._set_scope() # scope_name should be available to map_funcs + if os.path.isdir(save_path): + # If we don't have a name yet, set no parent. + save_path = os.path.join(save_path, network.name.replace("/", "_")) + user_map_func = map_func + if map_func is None: + map_func = _make_prefix_stripping_map_fn(network.scope_name) + # Step one is to restore any existing variables from the checkpoint. + existing_variables_by_checkpoint_name = _restore_existing_variables( + network=network, + save_path=save_path, + map_func=map_func, + user_map_func=user_map_func) + # Step two is to set a custom getter which restores variables on creation, + # for those variables which have not been added to sub-Layers yet. + _set_restore_on_create( + network=network, + save_path=save_path, + map_func=map_func, + user_map_func=user_map_func, + existing_variables_by_checkpoint_name=( + existing_variables_by_checkpoint_name)) diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 1127055c050..e7835a63e6d 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -19,9 +19,12 @@ from __future__ import print_function import gc from tensorflow.contrib.eager.python import network +from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core from tensorflow.python.ops import math_ops @@ -46,8 +49,8 @@ class NetworkTest(test.TestCase): def _save_modify_load_network_built(self, net, global_step=None): checkpoint_directory = self.get_temp_dir() - checkpoint_path = net.save( - save_path=checkpoint_directory, global_step=global_step) + checkpoint_path = network.save_network_checkpoint( + network=net, save_path=checkpoint_directory, global_step=global_step) input_value = constant_op.constant([[42.0]]) original_output = self.evaluate(net(input_value)) for var in net.variables: @@ -56,13 +59,13 @@ class NetworkTest(test.TestCase): self.evaluate(net(input_value)), original_output) # Either the returned explicit checkpoint path or the directory should work. - net.restore(save_path=checkpoint_directory) + network.restore_network_checkpoint(net, save_path=checkpoint_directory) self.assertAllEqual( original_output, self.evaluate(net(input_value))) for var in net.variables: self.evaluate(var.assign(var + 2.)) - net.restore(save_path=checkpoint_path) + network.restore_network_checkpoint(net, save_path=checkpoint_path) self.assertAllEqual( original_output, self.evaluate(net(input_value))) @@ -85,13 +88,30 @@ class NetworkTest(test.TestCase): result = net(constant_op.constant([[2.0]])) self.assertEqual(34.0, self.evaluate(result)) + # TODO(akshayka): This test should be changed once an API for compiling + # `call` into a defun is implemented. + def testReplacingNetworkCallWithDefun(self): + net = MyNetwork(name="abcd") + x = constant_op.constant([[2.0]]) + net(x) # Force variables to be created. + self.evaluate(net.trainable_variables[0].assign([[17.0]])) + + net.call = function.defun(net.call) + result = net(x) # Build and execute the TensorFlow function + self.assertEqual(34.0, self.evaluate(result)) + + # Force the creation of another TensorFlow function by changing input shape + y = constant_op.constant([[1.0], [2.0]]) + result = net(y) + self.assertAllEqual([[17.0], [34.0]], self.evaluate(result)) + # TODO(allenl): This test creates garbage in some Python versions @test_util.run_in_graph_and_eager_modes() def testNetworkSaveRestoreAlreadyBuilt(self): net = MyNetwork(name="abcd") with self.assertRaisesRegexp( ValueError, "Attempt to save the Network before it was first called"): - net.save(self.get_temp_dir()) + network.save_network_checkpoint(net, self.get_temp_dir()) net(constant_op.constant([[2.0]])) self.evaluate(net.trainable_variables[0].assign([[17.0]])) self._save_modify_load_network_built(net, global_step=None) @@ -105,7 +125,7 @@ class NetworkTest(test.TestCase): self.evaluate(net.variables[0].assign([[3.]])) default_global_step = training_util.get_or_create_global_step() self.evaluate(default_global_step.assign(4242)) - save_path = net.save(self.get_temp_dir()) + save_path = network.save_network_checkpoint(net, self.get_temp_dir()) self.assertIn("abcd-4242", save_path) # TODO(allenl): This test creates garbage in some Python versions @@ -116,16 +136,43 @@ class NetworkTest(test.TestCase): test_input = constant_op.constant([[2.0]]) net1(test_input) self.evaluate(net1.trainable_variables[0].assign([[17.0]])) - save_path = net1.save(save_dir) + save_path = network.save_network_checkpoint(net1, save_dir) # With a pre-build restore we should have the same value. net2 = MyNetwork() - net2.restore(save_path) + network.restore_network_checkpoint(net2, save_path) self.assertAllEqual(self.evaluate(net1(test_input)), self.evaluate(net2(test_input))) self.assertIsNot(net1.variables[0], net2.variables[0]) self.assertAllEqual(self.evaluate(net1.variables[0]), self.evaluate(net2.variables[0])) + @test_util.run_in_graph_and_eager_modes() + def testNetworkMatchesLayerVariableNames(self): + zero = constant_op.constant([[0.]]) + layer_one = core.Dense(1, use_bias=False) + layer_one(zero) + layer_two = core.Dense(1, use_bias=False) + layer_two(zero) + + class TwoLayerNet(network.Network): + + def __init__(self, name=None): + super(TwoLayerNet, self).__init__(name=name) + self.first = self.track_layer(core.Dense( + 1, use_bias=False)) + self.second = self.track_layer(core.Dense( + 1, use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + net = TwoLayerNet() + net(zero) + self.assertEqual("two_layer_net/" + layer_one.variables[0].name, + net.first.variables[0].name) + self.assertEqual("two_layer_net/" + layer_two.variables[0].name, + net.second.variables[0].name) + @test_util.run_in_graph_and_eager_modes() def testLoadIntoUnbuiltSharedLayer(self): @@ -173,14 +220,15 @@ class NetworkTest(test.TestCase): # Re-map the variable names so that with default restore mapping we'll # attempt to restore into the unbuilt Layer. name_mapping = { - "checkpoint_creator/first_layer/kernel": "owner_1/first_layer/kernel", + "checkpoint_creator/first_layer/kernel": "owner/first_layer/kernel", "checkpoint_creator/second_layer/kernel": "second_layer/kernel", } - save_path = checkpoint_creator.save( + save_path = network.save_network_checkpoint( + checkpoint_creator, self.get_temp_dir(), map_func=lambda full_name: name_mapping[full_name]) load_into = User(use_layer=first_owner.first) - load_into.restore(save_path) + network.restore_network_checkpoint(load_into, save_path) self.assertEqual(0, len(first_owner.variables)) self.assertAllEqual(self.evaluate(checkpoint_creator(one)), self.evaluate(load_into(one))) @@ -196,12 +244,13 @@ class NetworkTest(test.TestCase): del first_owner gc.collect() def _restore_map_func(original_name): - if original_name.startswith("owner_1"): - return original_name.replace("owner_1", "owner_2") + if original_name.startswith("owner/"): + return original_name.replace("owner/", "owner_1/") else: - return "user_2/" + original_name + return "user_1/" + original_name with self.assertRaisesRegexp(ValueError, "garbage collected"): - load_into.restore(save_path, map_func=_restore_map_func) + network.restore_network_checkpoint( + load_into, save_path, map_func=_restore_map_func) @test_util.run_in_graph_and_eager_modes() def testRestoreIntoSubNetwork(self): @@ -221,17 +270,18 @@ class NetworkTest(test.TestCase): whole_model_saver(one) self.evaluate(whole_model_saver.variables[0].assign([[15.]])) self.evaluate(whole_model_saver.variables[1].assign([[16.]])) - whole_model_checkpoint = whole_model_saver.save(self.get_temp_dir()) + whole_model_checkpoint = network.save_network_checkpoint( + whole_model_saver, self.get_temp_dir()) save_from = MyNetwork() save_from(one) self.evaluate(save_from.variables[0].assign([[5.]])) - checkpoint = save_from.save(self.get_temp_dir()) + checkpoint = network.save_network_checkpoint(save_from, self.get_temp_dir()) save_into_parent = Parent() - save_into_parent.restore(whole_model_checkpoint) - save_into_parent.first.restore(checkpoint) - save_into_parent.first.restore(checkpoint) # deferred loading multiple - # times is fine + network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint) + network.restore_network_checkpoint(save_into_parent.first, checkpoint) + # deferred loading multiple times is fine + network.restore_network_checkpoint(save_into_parent.first, checkpoint) save_into_parent(one) # deferred loading self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0])) self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1])) @@ -240,9 +290,9 @@ class NetworkTest(test.TestCase): # (deferred restoration should happen the same way non-deferred happens, # with later restorations overwriting older ones). save_into_parent = Parent() - save_into_parent.first.restore(checkpoint) # deferred loading multiple - # times is fine - save_into_parent.restore(whole_model_checkpoint) + # deferred loading multiple times is fine + network.restore_network_checkpoint(save_into_parent.first, checkpoint) + network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint) save_into_parent(one) # deferred loading # We've overwritten the sub-Network restore. self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0])) @@ -250,12 +300,12 @@ class NetworkTest(test.TestCase): self.evaluate(save_into_parent.variables[0].assign([[3.]])) self.evaluate(save_into_parent.variables[1].assign([[4.]])) - save_into_parent.second.restore(checkpoint) + network.restore_network_checkpoint(save_into_parent.second, checkpoint) self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1])) with self.assertRaisesRegexp(errors_impl.NotFoundError, "not found in checkpoint"): # The checkpoint is incompatible. - save_into_parent.restore(checkpoint) + network.restore_network_checkpoint(save_into_parent, checkpoint) @test_util.run_in_graph_and_eager_modes() def testCustomMapCollisionErrors(self): @@ -277,31 +327,36 @@ class NetworkTest(test.TestCase): self.evaluate(make_checkpoint.variables[1].assign([[3.]])) with self.assertRaisesRegexp( ValueError, - "The map_func passed to Network.save for the Network 'parent_1' " - "resulted in two variables named 'foo'"): - make_checkpoint.save(self.get_temp_dir(), map_func=lambda n: "foo") - checkpoint = make_checkpoint.first.save( - self.get_temp_dir(), map_func=lambda n: "foo") + "The map_func passed to save_network_checkpoint for the Network " + "'parent' resulted in two variables named 'foo'"): + network.save_network_checkpoint( + make_checkpoint, self.get_temp_dir(), map_func=lambda n: "foo") + checkpoint = network.save_network_checkpoint( + network=make_checkpoint.first, + save_path=self.get_temp_dir(), + map_func=lambda n: "foo") loader = Parent() - loader.restore(checkpoint, map_func=lambda n: "foo") + network.restore_network_checkpoint( + loader, checkpoint, map_func=lambda n: "foo") with self.assertRaisesRegexp( ValueError, - ("The map_func passed to Network.restore for the Network" - " 'parent_2' resulted in two variables named 'foo'")): + ("The map_func passed to restore_network_checkpoint for the Network" + " 'parent_1' resulted in two variables named 'foo'")): loader(one) loader = Parent() loader(one) with self.assertRaisesRegexp( ValueError, - ("The map_func passed to Network.restore for the Network" - " 'parent_3' resulted in two variables named 'foo'")): - loader.restore(checkpoint, map_func=lambda n: "foo") + ("The map_func passed to restore_network_checkpoint for the Network" + " 'parent_2' resulted in two variables named 'foo'")): + network.restore_network_checkpoint( + loader, checkpoint, map_func=lambda n: "foo") @test_util.run_in_graph_and_eager_modes() def testDefaultMapCollisionErrors(self): one = constant_op.constant([[1.]]) - first = core.Dense(1, name="dense_1", use_bias=False) + first = core.Dense(1, name="dense", use_bias=False) first(one) class Parent(network.Network): @@ -322,8 +377,8 @@ class NetworkTest(test.TestCase): with self.assertRaisesRegexp( ValueError, ("The default checkpoint variable name mapping strategy for Network " - "'parent_1' resulted in a naming conflict.")): - make_checkpoint.save(self.get_temp_dir()) + "'parent' resulted in a naming conflict.")): + network.save_network_checkpoint(make_checkpoint, self.get_temp_dir()) class Compatible(network.Network): @@ -337,14 +392,15 @@ class NetworkTest(test.TestCase): successful_checkpoint = Compatible() successful_checkpoint(one) self.evaluate(successful_checkpoint.variables[0].assign([[-1.]])) - checkpoint_path = successful_checkpoint.save(self.get_temp_dir()) + checkpoint_path = network.save_network_checkpoint( + successful_checkpoint, self.get_temp_dir()) load_checkpoint = Parent() load_checkpoint(one) with self.assertRaisesRegexp( ValueError, ("The default checkpoint variable name mapping strategy for Network " - "'parent_2' resulted in a naming conflict.")): - load_checkpoint.restore(checkpoint_path) + "'parent_1' resulted in a naming conflict.")): + network.restore_network_checkpoint(load_checkpoint, checkpoint_path) def testNoReferenceCyclesAfterCall(self): @@ -398,6 +454,36 @@ class NetworkTest(test.TestCase): self.assertIsInstance(net.trainable_weights[0], resource_variable_ops.ResourceVariable) + def testGraphOpNames(self): + """Network operation names should match variable naming.""" + + def _check_op_prefixes(expected_prefix, checked_ops): + for operation in ops.get_default_graph().get_operations(): + if operation.name == "ignore": + continue + if operation.name in checked_ops: + continue + checked_ops.add(operation.name) + self.assertStartsWith(expected_start=expected_prefix, + actual=operation.name) + self.assertNotIn("my_network", operation.name[len(expected_prefix):]) + self.assertNotIn("dense", operation.name[len(expected_prefix):]) + + with context.graph_mode(): + net = MyNetwork() + zero = constant_op.constant([[0.]], name="ignore") + net(zero) + checked_ops = set() + _check_op_prefixes(expected_prefix="my_network/dense/", + checked_ops=checked_ops) + net.net2 = net.track_layer(MyNetwork()) + net.net2(zero) + _check_op_prefixes(expected_prefix="my_network/my_network/dense/", + checked_ops=checked_ops) + MyNetwork()(zero) + _check_op_prefixes(expected_prefix="my_network_1/dense/", + checked_ops=checked_ops) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testDuplicateNameError(self): one = constant_op.constant([[1.]]) @@ -414,25 +500,25 @@ class NetworkTest(test.TestCase): # Naming happens in the order of first build rather than the order of # construction, but for clarity they're the same here and construction is # annotated. - outside_net_before = MyNetwork() # name=my_network_1 + outside_net_before = MyNetwork() # name=my_network outside_net_before(one) captured_scope = variable_scope.get_variable_scope() with variable_scope.variable_scope("outside_scope"): - net1 = MyNetwork() # name=outside_scope/my_network_1 + net1 = MyNetwork() # name=outside_scope/my_network net1(one) name_conflict1 = MyNetwork(name="name_conflict") # fine, unique so far name_conflict2 = MyNetwork(name="name_conflict") # error on build with variable_scope.variable_scope("inside_scope"): # No issue here since the name is unique within its scope. name_conflict3 = MyNetwork(name="name_conflict") - net2 = MyNetwork() # name=outside_scope/my_network_3 to avoid the - # variable_scope my_network_2 below. + net2 = MyNetwork() # name=outside_scope/my_network_2 to avoid the + # variable_scope my_network_1 below. vs_name_conflict = MyNetwork(name="vs_name_conflict") # conflict below with variable_scope.variable_scope("intervening_scope"): with variable_scope.variable_scope(captured_scope): with variable_scope.variable_scope("outside_scope"): name_conflict4 = MyNetwork(name="name_conflict") # error on build - with variable_scope.variable_scope("my_network_2"): + with variable_scope.variable_scope("my_network_1"): pass with variable_scope.variable_scope("vs_name_conflict"): pass @@ -452,35 +538,35 @@ class NetworkTest(test.TestCase): self.assertEqual("outside_scope/name_conflict", name_conflict1.name) self.assertStartsWith( - expected_start="outside_scope/name_conflict/dense_1/", + expected_start="outside_scope/name_conflict/dense/", actual=name_conflict1.variables[0].name) self.assertEqual("outside_scope/inside_scope/name_conflict", name_conflict3.name) self.assertStartsWith( - expected_start="outside_scope/inside_scope/name_conflict/dense_1/", + expected_start="outside_scope/inside_scope/name_conflict/dense/", actual=name_conflict3.variables[0].name) - self.assertEqual("outside_scope/my_network_1", net1.name) + self.assertEqual("outside_scope/my_network", net1.name) self.assertStartsWith( - expected_start="outside_scope/my_network_1/dense_1/", + expected_start="outside_scope/my_network/dense/", actual=net1.trainable_weights[0].name) - self.assertEqual("outside_scope/my_network_3", net2.name) + self.assertEqual("outside_scope/my_network_2", net2.name) self.assertStartsWith( - expected_start="outside_scope/my_network_3/dense_1/", + expected_start="outside_scope/my_network_2/dense/", actual=net2.trainable_weights[0].name) net3(one) - self.assertEqual("outside_scope/my_network_4", net3.name) + self.assertEqual("outside_scope/my_network_3", net3.name) self.assertStartsWith( - expected_start="outside_scope/my_network_4/dense_1/", + expected_start="outside_scope/my_network_3/dense/", actual=net3.trainable_weights[0].name) outside_net_after = MyNetwork() outside_net_after(one) - self.assertEqual("my_network_1", outside_net_before.name) + self.assertEqual("my_network", outside_net_before.name) self.assertStartsWith( - expected_start="my_network_1/dense_1/", + expected_start="my_network/dense/", actual=outside_net_before.trainable_weights[0].name) - self.assertEqual("my_network_2", outside_net_after.name) + self.assertEqual("my_network_1", outside_net_after.name) self.assertStartsWith( - expected_start="my_network_2/dense_1/", + expected_start="my_network_1/dense/", actual=outside_net_after.trainable_weights[0].name) @test_util.run_in_graph_and_eager_modes() @@ -490,21 +576,21 @@ class NetworkTest(test.TestCase): net = MyNetwork() net(constant_op.constant([[2.0]])) self.evaluate(net.variables[0].assign([[42.]])) - self.assertEqual(net.name, "scope1/scope2/my_network_1") + self.assertEqual(net.name, "scope1/scope2/my_network") self.assertStartsWith( - expected_start="scope1/scope2/my_network_1/dense_1/", + expected_start="scope1/scope2/my_network/dense/", actual=net.trainable_weights[0].name) - save_path = net.save(self.get_temp_dir()) - self.assertIn("scope1_scope2_my_network_1", save_path) + save_path = network.save_network_checkpoint(net, self.get_temp_dir()) + self.assertIn("scope1_scope2_my_network", save_path) restore_net = MyNetwork() # Delayed restoration - restore_net.restore(save_path) + network.restore_network_checkpoint(restore_net, save_path) restore_net(constant_op.constant([[1.0]])) self.assertAllEqual([[42.]], self.evaluate(restore_net.variables[0])) self.evaluate(restore_net.variables[0].assign([[-1.]])) # Immediate restoration - restore_net.restore(save_path) + network.restore_network_checkpoint(restore_net, save_path) self.assertAllEqual([[42.]], self.evaluate(restore_net.variables[0])) @@ -523,7 +609,7 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net = ParentNetwork() net(one) - self.assertStartsWith(expected_start="parent_network_1/explicit_name/", + self.assertStartsWith(expected_start="parent_network/explicit_name/", actual=net.trainable_weights[0].name) self.assertEqual("explicit_name", net.first.name) @@ -578,15 +664,15 @@ class NetworkTest(test.TestCase): # locally so that previous Layer consutrciton does not interfere with # variable naming (e.g. add a Layer construction before the Network, # suddenly your previously saved checkpoint is incompatible). - self.assertEqual("dense_1", net1.l1.name) - self.assertEqual("dense_1", net2.l1.name) + self.assertEqual("dense", net1.l1.name) + self.assertEqual("dense", net2.l1.name) self.evaluate(net1.trainable_weights[0].assign([[1.]])) self.evaluate(net2.trainable_weights[0].assign([[2.]])) self.assertEqual(2., self.evaluate(net2.trainable_weights[0])) self.assertEqual(1., self.evaluate(net1.trainable_weights[0])) - self.assertStartsWith(expected_start="my_network_1/dense_1/", + self.assertStartsWith(expected_start="my_network/dense/", actual=net1.trainable_weights[0].name) - self.assertStartsWith(expected_start="my_network_2/dense_1/", + self.assertStartsWith(expected_start="my_network_1/dense/", actual=net2.trainable_weights[0].name) @test_util.run_in_graph_and_eager_modes() @@ -607,31 +693,31 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net = ParentNetwork() net(one) - self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network/my_network/dense", actual=net.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network/my_network/dense", actual=net.first.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network/my_network_1/dense", actual=net.trainable_weights[1].name) - self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network/my_network_1/dense", actual=net.second.trainable_weights[0].name) - self.assertEqual("parent_network_1", net.name) - self.assertEqual("my_network_1", net.first.name) - self.assertEqual("my_network_2", net.second.name) + self.assertEqual("parent_network", net.name) + self.assertEqual("my_network", net.first.name) + self.assertEqual("my_network_1", net.second.name) net2 = ParentNetwork() net2(one) - self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network/dense", actual=net2.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network/dense", actual=net2.first.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", actual=net2.trainable_weights[1].name) - self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", actual=net2.second.trainable_weights[0].name) - self.assertEqual("parent_network_2", net2.name) - self.assertEqual("my_network_1", net2.first.name) - self.assertEqual("my_network_2", net2.second.name) + self.assertEqual("parent_network_1", net2.name) + self.assertEqual("my_network", net2.first.name) + self.assertEqual("my_network_1", net2.second.name) @test_util.run_in_graph_and_eager_modes() def testNestableExplicit(self): @@ -692,26 +778,26 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net = MixedLayerNetwork() net(one) - self.assertEqual("dense_1", net.first.name) - self.assertEqual("dense_2", net.second.name) - self.assertEqual("dense_3", net.third.name) - self.assertEqual("dense_4", net.fourth.name) - self.assertEqual("dense_5", net.fifth.name) + self.assertEqual("dense", net.first.name) + self.assertEqual("dense_1", net.second.name) + self.assertEqual("dense_2", net.third.name) + self.assertEqual("dense_3", net.fourth.name) + self.assertEqual("dense_4", net.fifth.name) # Note that this is _not_ the default naming behavior for Layers. Layers # which are added to Networks follow Network variable naming conventions # (i.e. variable names = network name unless variable sharing). Nested # Layers revert to Layer behavior. - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_1/", + self.assertStartsWith(expected_start="mixed_layer_network/dense/", actual=net.trainable_weights[0].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_2/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_1/", actual=net.trainable_weights[1].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_3/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_2/", actual=net.trainable_weights[2].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_4/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_3/", actual=net.trainable_weights[3].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_5/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_4/", actual=net.trainable_weights[4].name) - self.assertEqual("mixed_layer_network_1", net.name) + self.assertEqual("mixed_layer_network", net.name) @test_util.run_in_graph_and_eager_modes() def testNestableExplicitCollisions(self): @@ -764,24 +850,24 @@ class NetworkTest(test.TestCase): net = ParentNetwork() net(one) self.assertStartsWith( - expected_start="parent_network_1/first_unique_child_name/dense_1/", + expected_start="parent_network/first_unique_child_name/dense/", actual=net.trainable_weights[0].name) self.assertStartsWith( - expected_start="parent_network_1/second_unique_child_name/dense_1/", + expected_start="parent_network/second_unique_child_name/dense/", actual=net.trainable_weights[1].name) - self.assertEqual("parent_network_1", net.name) + self.assertEqual("parent_network", net.name) self.assertEqual("first_unique_child_name", net.first.name) self.assertEqual("second_unique_child_name", net.second.name) net2 = ParentNetwork() net2(one) self.assertStartsWith( - expected_start="parent_network_2/first_unique_child_name/dense", + expected_start="parent_network_1/first_unique_child_name/dense", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="parent_network_2/second_unique_child_name/dense", + expected_start="parent_network_1/second_unique_child_name/dense", actual=net2.trainable_weights[1].name) - self.assertEqual("parent_network_2", net2.name) + self.assertEqual("parent_network_1", net2.name) self.assertEqual("first_unique_child_name", net2.first.name) self.assertEqual("second_unique_child_name", net2.second.name) @@ -839,15 +925,15 @@ class NetworkTest(test.TestCase): net2(one) self.assertStartsWith( - expected_start="first_parent_network_1/my_network_1/dense_1/", + expected_start="first_parent_network/my_network/dense/", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="second_parent_network_1/my_network_1/dense_1/", + expected_start="second_parent_network/my_network/dense/", actual=net2.trainable_weights[1].name) - self.assertEqual("second_parent_network_1", net2.name) + self.assertEqual("second_parent_network", net2.name) self.assertTrue(net2.first is net.first) - self.assertEqual("my_network_1", net2.first.name) - self.assertEqual("my_network_1", net2.second.name) + self.assertEqual("my_network", net2.first.name) + self.assertEqual("my_network", net2.second.name) # No name collision; the owned Network is added first and has a different # name than the shared Network. @@ -865,15 +951,15 @@ class NetworkTest(test.TestCase): net3(one) self.assertStartsWith( - expected_start="third_parent_network_1/my_network_1/dense", + expected_start="third_parent_network/my_network/dense", actual=net3.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_parent_network_1/my_network_2/dense", + expected_start="first_parent_network/my_network_1/dense", actual=net3.trainable_weights[1].name) - self.assertEqual("third_parent_network_1", net3.name) + self.assertEqual("third_parent_network", net3.name) self.assertTrue(net3.second is net.second) - self.assertEqual("my_network_1", net3.first.name) - self.assertEqual("my_network_2", net3.second.name) + self.assertEqual("my_network", net3.first.name) + self.assertEqual("my_network_1", net3.second.name) # "Unavoidable" same-name Layer. The owned name is added first (fixed), then # a shared Network is added with the same name. @@ -891,15 +977,15 @@ class NetworkTest(test.TestCase): net4(one) self.assertStartsWith( - expected_start="fourth_parent_network_1/my_network_1/dense_1/", + expected_start="fourth_parent_network/my_network/dense/", actual=net4.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_parent_network_1/my_network_1/dense_1/", + expected_start="first_parent_network/my_network/dense/", actual=net4.trainable_weights[1].name) - self.assertEqual("fourth_parent_network_1", net4.name) + self.assertEqual("fourth_parent_network", net4.name) self.assertTrue(net4.second is net.first) - self.assertEqual("my_network_1", net4.first.name) - self.assertEqual("my_network_1", net4.second.name) + self.assertEqual("my_network", net4.first.name) + self.assertEqual("my_network", net4.second.name) @test_util.run_in_graph_and_eager_modes() def testRecursiveLayerRenaming(self): @@ -930,28 +1016,28 @@ class NetworkTest(test.TestCase): net(one) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_1/" - "dense_1/"), + expected_start=("parent_network/network_with_layer_children/" + "dense/"), actual=net.trainable_weights[0].name) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_1/" - "dense_2/"), + expected_start=("parent_network/network_with_layer_children/" + "dense_1/"), actual=net.trainable_weights[1].name) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_2/" - "dense_1/"), + expected_start=("parent_network/network_with_layer_children_1/" + "dense/"), actual=net.trainable_weights[2].name) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_2/" - "dense_2/"), + expected_start=("parent_network/network_with_layer_children_1/" + "dense_1/"), actual=net.trainable_weights[3].name) - self.assertEqual("parent_network_1", net.name) - self.assertEqual("network_with_layer_children_1", net.first.name) - self.assertEqual("network_with_layer_children_2", net.second.name) - self.assertEqual("dense_1", net.first.first.name) - self.assertEqual("dense_2", net.first.second.name) - self.assertEqual("dense_1", net.second.first.name) - self.assertEqual("dense_2", net.second.second.name) + self.assertEqual("parent_network", net.name) + self.assertEqual("network_with_layer_children", net.first.name) + self.assertEqual("network_with_layer_children_1", net.second.name) + self.assertEqual("dense", net.first.first.name) + self.assertEqual("dense_1", net.first.second.name) + self.assertEqual("dense", net.second.first.name) + self.assertEqual("dense_1", net.second.second.name) @test_util.run_in_graph_and_eager_modes() def testCallInDifferentOrderThanConstruct(self): @@ -985,23 +1071,23 @@ class NetworkTest(test.TestCase): net1(one) self.assertStartsWith( - expected_start="first_network_1/my_network_1/dense_1/", + expected_start="first_network/my_network/dense/", actual=net1.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_network_1/my_network_2/dense_1/", + expected_start="first_network/my_network_1/dense/", actual=net1.trainable_weights[1].name) self.assertStartsWith( - expected_start="first_network_1/my_network_1/dense_1/", + expected_start="first_network/my_network/dense/", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="second_network_1/my_network_1/dense_1/", + expected_start="second_network/my_network/dense/", actual=net2.trainable_weights[1].name) self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0]) - self.assertEqual("first_network_1", net1.name) - self.assertEqual("my_network_1", net1.first.name) - self.assertEqual("my_network_2", net1.second.name) + self.assertEqual("first_network", net1.name) + self.assertEqual("my_network", net1.first.name) + self.assertEqual("my_network_1", net1.second.name) self.assertTrue(net2.first is net1.first) - self.assertEqual("my_network_1", net2.second.name) + self.assertEqual("my_network", net2.second.name) @test_util.run_in_graph_and_eager_modes() def testLayerCallInDifferentOrderThanConstruct(self): @@ -1038,23 +1124,23 @@ class NetworkTest(test.TestCase): net1(one) self.assertStartsWith( - expected_start="first_network_1/dense_1/", + expected_start="first_network/dense/", actual=net1.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_network_1/dense_2/", + expected_start="first_network/dense_1/", actual=net1.trainable_weights[1].name) self.assertStartsWith( - expected_start="first_network_1/dense_1/", + expected_start="first_network/dense/", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="second_network_1/dense_1/", + expected_start="second_network/dense/", actual=net2.trainable_weights[1].name) self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0]) - self.assertEqual("first_network_1", net1.name) - self.assertEqual("dense_1", net1.first.name) - self.assertEqual("dense_2", net1.second.name) + self.assertEqual("first_network", net1.name) + self.assertEqual("dense", net1.first.name) + self.assertEqual("dense_1", net1.second.name) self.assertTrue(net2.first is net1.first) - self.assertEqual("dense_1", net2.second.name) + self.assertEqual("dense", net2.second.name) @test_util.run_in_graph_and_eager_modes() def testLayerAlreadyBuilt(self): @@ -1083,13 +1169,13 @@ class NetworkTest(test.TestCase): # do not match their layer names. actual=net.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_network_1/dense_1/", + expected_start="first_network/dense/", actual=net.trainable_weights[1].name) self.assertTrue( net.trainable_weights[0] is shared_layer.trainable_weights[0]) - self.assertEqual("first_network_1", net.name) + self.assertEqual("first_network", net.name) self.assertEqual("dense_3", net.first.name) - self.assertEqual("dense_1", net.second.name) + self.assertEqual("dense", net.second.name) class SequentialTest(test.TestCase): diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index b6c687c8294..1697c879def 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -30,9 +30,6 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@value_and_gradients_function @@GradientTape -@@enable_tracing -@@flush_trace - @@run @@enable_eager_execution @@ -46,13 +43,16 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@seterr @@Iterator -@@Network @@Saver @@restore_variables_on_create @@Variable @@get_optimizer_variables @@EagerVariableStore +@@Network +@@save_network_checkpoint +@@restore_network_checkpoint + @@in_eager_mode @@in_graph_mode @@ -74,6 +74,8 @@ from __future__ import print_function from tensorflow.contrib.eager.python import metrics from tensorflow.contrib.eager.python.datasets import Iterator from tensorflow.contrib.eager.python.network import Network +from tensorflow.contrib.eager.python.network import save_network_checkpoint +from tensorflow.contrib.eager.python.network import restore_network_checkpoint from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver @@ -86,7 +88,6 @@ from tensorflow.python.eager.context import in_eager_mode from tensorflow.python.eager.context import in_graph_mode from tensorflow.python.eager.context import list_devices from tensorflow.python.eager.context import num_gpus -from tensorflow.python.eager.core import enable_tracing from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index bc67ef83541..008ca7a5d17 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -208,6 +208,7 @@ py_library( "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", "//tensorflow/python:summary", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index 73bae5acf9c..f2a6eae03ec 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary @@ -342,14 +343,19 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access predictions = {} metrics = {} losses = [] - for head, spec in zip(self._heads, all_estimator_spec): - losses.append(spec.loss) - head_name = head.name - # Metric keys already contain head.name. - metrics.update(spec.eval_metric_ops or {}) - for k, v in six.iteritems(spec.predictions): - predictions[(head_name, k)] = v - loss = _merge_losses(losses, self._head_weights) + with ops.name_scope('merge_eval'): + for head, spec in zip(self._heads, all_estimator_spec): + losses.append(spec.loss) + head_name = head.name + # Loss metric is not added by default. + loss_name = head_lib._summary_key( # pylint:disable=protected-access + head_name, metric_keys.MetricKeys.LOSS) + metrics[loss_name] = metrics_lib.mean(spec.loss, name=loss_name) + # Metric keys already contain head.name. + metrics.update(spec.eval_metric_ops or {}) + for k, v in six.iteritems(spec.predictions): + predictions[(head_name, k)] = v + loss = _merge_losses(losses, self._head_weights) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 8d51a298b23..68f2d5d1cd5 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -297,6 +297,8 @@ class MultiHeadTest(test.TestCase): keys = metric_keys.MetricKeys expected_metrics = { + keys.LOSS + '/head1': expected_loss_head1, + keys.LOSS + '/head2': expected_loss_head2, # Average loss over examples. keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 3976395d78e..b2f22eb2fce 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.factorization.python.ops import factorization_ops -from tensorflow.contrib.framework.python.ops import variables as framework_variables from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.framework import dtypes @@ -32,175 +31,64 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util class _SweepHook(session_run_hook.SessionRunHook): """Keeps track of row/col sweeps, and runs prep ops before each sweep.""" - def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols, - input_row_indices, input_col_indices, row_prep_ops, - col_prep_ops, init_op, completed_sweeps_var): + def __init__(self, is_row_sweep_var, is_sweep_done_var, init_op, + row_prep_ops, col_prep_ops, row_train_op, col_train_op, + switch_op): """Initializes SweepHook. Args: is_row_sweep_var: A Boolean tf.Variable, determines whether we are currently doing a row or column sweep. It is updated by the hook. - train_ops: A list of ops. The ops created by this hook will have - control dependencies on `train_ops`. - num_rows: int, the total number of rows to be processed. - num_cols: int, the total number of columns to be processed. - input_row_indices: A Tensor of type int64. The indices of the input rows - that are processed during the current sweep. All elements of - `input_row_indices` must be in [0, num_rows). - input_col_indices: A Tensor of type int64. The indices of the input - columns that are processed during the current sweep. All elements of - `input_col_indices` must be in [0, num_cols). - row_prep_ops: list of ops, to be run before the beginning of each row - sweep, in the given order. - col_prep_ops: list of ops, to be run before the beginning of each column - sweep, in the given order. + is_sweep_done_var: A Boolean tf.Variable, determines whether we are + starting a new sweep (this is used to determine when to run the prep ops + below). init_op: op to be run once before training. This is typically a local initialization op (such as cache initialization). - completed_sweeps_var: An integer tf.Variable, indicates the number of - completed sweeps. It is updated by the hook. + row_prep_ops: A list of TensorFlow ops, to be run before the beginning of + each row sweep (and during initialization), in the given order. + col_prep_ops: A list of TensorFlow ops, to be run before the beginning of + each column sweep (and during initialization), in the given order. + row_train_op: A TensorFlow op to be run during row sweeps. + col_train_op: A TensorFlow op to be run during column sweeps. + switch_op: A TensorFlow op to be run before each sweep. """ - self._num_rows = num_rows - self._num_cols = num_cols + self._is_row_sweep_var = is_row_sweep_var + self._is_sweep_done_var = is_sweep_done_var + self._init_op = init_op self._row_prep_ops = row_prep_ops self._col_prep_ops = col_prep_ops - self._init_op = init_op - self._is_row_sweep_var = is_row_sweep_var - self._completed_sweeps_var = completed_sweeps_var - # Boolean variable that determines whether the init_ops have been run. + self._row_train_op = row_train_op + self._col_train_op = col_train_op + self._switch_op = switch_op + # Boolean variable that determines whether the init_op has been run. self._is_initialized = False - # Ops to run jointly with train_ops, responsible for updating - # `is_row_sweep_var` and incrementing the `global_step` and - # `completed_sweeps` counters. - self._update_op, self._is_sweep_done_var, self._switch_op = ( - self._create_hook_ops(input_row_indices, input_col_indices, train_ops)) - - def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops): - """Creates ops to update is_row_sweep_var, global_step and completed_sweeps. - - Creates two boolean tensors `processed_rows` and `processed_cols`, which - keep track of which rows/cols have been processed during the current sweep. - Returns ops that should be run after each row / col update. - - When `self._is_row_sweep_var` is True, it sets - processed_rows[input_row_indices] to True. - - When `self._is_row_sweep_var` is False, it sets - processed_cols[input_col_indices] to True. - - Args: - input_row_indices: A Tensor. The indices of the input rows that are - processed during the current sweep. - input_col_indices: A Tensor. The indices of the input columns that - are processed during the current sweep. - train_ops: A list of ops. The ops created by this function have control - dependencies on `train_ops`. - - Returns: - A tuple consisting of: - update_op: An op to be run jointly with training. It updates the state - and increments counters (global step and completed sweeps). - is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is - done, i.e. all rows (during a row sweep) or all columns (during a - column sweep) have been processed. - switch_op: An op to be run in `self.before_run` when the sweep is done. - """ - processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False) - with ops.colocate_with(processed_rows_init): - processed_rows = variable_scope.variable( - processed_rows_init, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="sweep_hook_processed_rows") - processed_cols_init = array_ops.fill(dims=[self._num_cols], value=False) - with ops.colocate_with(processed_cols_init): - processed_cols = variable_scope.variable( - processed_cols_init, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="sweep_hook_processed_cols") - switch_ops = control_flow_ops.group( - state_ops.assign( - self._is_row_sweep_var, - math_ops.logical_not(self._is_row_sweep_var)), - state_ops.assign(processed_rows, processed_rows_init), - state_ops.assign(processed_cols, processed_cols_init)) - is_sweep_done_var = variable_scope.variable( - False, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="is_sweep_done") - - # After running the `train_ops`, updates `processed_rows` or - # `processed_cols` tensors, depending on whether this is a row or col sweep. - with ops.control_dependencies(train_ops): - with ops.colocate_with(processed_rows): - update_processed_rows = state_ops.scatter_update( - processed_rows, - input_row_indices, - math_ops.logical_and( - self._is_row_sweep_var, - array_ops.ones_like(input_row_indices, dtype=dtypes.bool))) - with ops.colocate_with(processed_cols): - update_processed_cols = state_ops.scatter_update( - processed_cols, - input_col_indices, - math_ops.logical_and( - math_ops.logical_not(self._is_row_sweep_var), - array_ops.ones_like(input_col_indices, dtype=dtypes.bool))) - update_processed_op = control_flow_ops.group( - update_processed_rows, update_processed_cols) - - with ops.control_dependencies([update_processed_op]): - is_sweep_done = math_ops.logical_or( - math_ops.reduce_all(processed_rows), - math_ops.reduce_all(processed_cols)) - # Increments global step. - global_step = framework_variables.get_global_step() - if global_step is not None: - global_step_incr_op = state_ops.assign_add( - global_step, 1, name="global_step_incr").op - else: - global_step_incr_op = control_flow_ops.no_op() - # Increments completed sweeps. - completed_sweeps_incr_op = state_ops.assign_add( - self._completed_sweeps_var, - math_ops.cast(is_sweep_done, dtypes.int32), - use_locking=True).op - update_ops = control_flow_ops.group( - global_step_incr_op, - completed_sweeps_incr_op, - state_ops.assign(is_sweep_done_var, is_sweep_done)) - - return update_ops, is_sweep_done_var, switch_ops def before_run(self, run_context): """Runs the appropriate prep ops, and requests running update ops.""" - # Runs the appropriate init ops and prep ops. sess = run_context.session is_sweep_done = sess.run(self._is_sweep_done_var) if not self._is_initialized: - logging.info("SweepHook running cache init op.") + logging.info("SweepHook running init op.") sess.run(self._init_op) if is_sweep_done: sess.run(self._switch_op) + is_row_sweep = sess.run(self._is_row_sweep_var) if is_sweep_done or not self._is_initialized: - logging.info("SweepHook running sweep prep ops.") - row_sweep = sess.run(self._is_row_sweep_var) - prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops + logging.info("SweepHook running prep ops for the {} sweep.".format( + "row" if is_row_sweep else "col")) + prep_ops = self._row_prep_ops if is_row_sweep else self._col_prep_ops for prep_op in prep_ops: sess.run(prep_op) - self._is_initialized = True - - # Requests running `self._update_op` jointly with the training op. logging.info("Next fit step starting.") - return session_run_hook.SessionRunArgs(fetches=[self._update_op]) - - def after_run(self, run_context, run_values): - logging.info("Fit step done.") + return session_run_hook.SessionRunArgs( + fetches=[self._row_train_op if is_row_sweep else self._col_train_op]) class _StopAtSweepHook(session_run_hook.SessionRunHook): @@ -246,6 +134,9 @@ def _wals_factorization_model_function(features, labels, mode, params): Returns: A ModelFnOps object. + + Raises: + ValueError: If `mode` is not recognized. """ assert labels is None use_factors_weights_cache = (params["use_factors_weights_cache_for_training"] @@ -269,86 +160,156 @@ def _wals_factorization_model_function(features, labels, mode, params): use_gramian_cache=use_gramian_cache) # Get input rows and cols. We either update rows or columns depending on - # the value of row_sweep, which is maintained using a session hook + # the value of row_sweep, which is maintained using a session hook. input_rows = features[WALSMatrixFactorization.INPUT_ROWS] input_cols = features[WALSMatrixFactorization.INPUT_COLS] - input_row_indices, _ = array_ops.unique(input_rows.indices[:, 0]) - input_col_indices, _ = array_ops.unique(input_cols.indices[:, 0]) - # Train ops, controlled using the SweepHook - # We need to run the following ops: - # Before a row sweep: - # row_update_prep_gramian_op - # initialize_row_update_op - # During a row sweep: - # update_row_factors_op - # Before a col sweep: - # col_update_prep_gramian_op - # initialize_col_update_op - # During a col sweep: - # update_col_factors_op + # TRAIN mode: + if mode == model_fn.ModeKeys.TRAIN: + # Training consists of the folowing ops (controlled using a SweepHook). + # Before a row sweep: + # row_update_prep_gramian_op + # initialize_row_update_op + # During a row sweep: + # update_row_factors_op + # Before a col sweep: + # col_update_prep_gramian_op + # initialize_col_update_op + # During a col sweep: + # update_col_factors_op - is_row_sweep_var = variable_scope.variable( - True, - trainable=False, - name="is_row_sweep", - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - completed_sweeps_var = variable_scope.variable( - 0, - trainable=False, - name=WALSMatrixFactorization.COMPLETED_SWEEPS, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + is_row_sweep_var = variable_scope.variable( + True, + trainable=False, + name="is_row_sweep", + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + is_sweep_done_var = variable_scope.variable( + False, + trainable=False, + name="is_sweep_done", + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + completed_sweeps_var = variable_scope.variable( + 0, + trainable=False, + name=WALSMatrixFactorization.COMPLETED_SWEEPS, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + loss_var = variable_scope.variable( + 0., + trainable=False, + name=WALSMatrixFactorization.LOSS, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + # The root weighted squared error = + # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) + rwse_var = variable_scope.variable( + 0., + trainable=False, + name=WALSMatrixFactorization.RWSE, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - # The row sweep is determined by is_row_sweep_var (controlled by the - # sweep_hook) in TRAIN mode, and manually in EVAL mode. - is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW] - if mode == model_fn.ModeKeys.EVAL else is_row_sweep_var) + summary.scalar("loss", loss_var) + summary.scalar("root_weighted_squared_error", rwse_var) + summary.scalar("completed_sweeps", completed_sweeps_var) - def update_row_factors(): - return model.update_row_factors(sp_input=input_rows, transpose_input=False) + # Increments global step. + global_step = training_util.get_global_step() + if global_step: + global_step_incr_op = state_ops.assign_add( + global_step, 1, name="global_step_incr").op + else: + global_step_incr_op = control_flow_ops.no_op() - def update_col_factors(): - return model.update_col_factors(sp_input=input_cols, transpose_input=True) + def create_axis_ops(sp_input, num_items, update_fn, axis_name): + """Creates book-keeping and training ops for a given axis. - (_, train_op, - unregularized_loss, regularization, sum_weights) = control_flow_ops.cond( - is_row_sweep, update_row_factors, update_col_factors) - loss = unregularized_loss + regularization - root_weighted_squared_error = math_ops.sqrt(unregularized_loss / sum_weights) + Args: + sp_input: A SparseTensor corresponding to the row or column batch. + num_items: An integer, the total number of items of this axis. + update_fn: A function that takes one argument (`sp_input`), and that + returns a tuple of + * new_factors: A flot Tensor of the factor values after update. + * update_op: a TensorFlow op which updates the factors. + * loss: A float Tensor, the unregularized loss. + * reg_loss: A float Tensor, the regularization loss. + * sum_weights: A float Tensor, the sum of factor weights. + axis_name: A string that specifies the name of the axis. - row_prep_ops = [ - model.row_update_prep_gramian_op, model.initialize_row_update_op - ] - col_prep_ops = [ - model.col_update_prep_gramian_op, model.initialize_col_update_op - ] - init_ops = [model.worker_init] + Returns: + A tuple consisting of: + * reset_processed_items_op: A TensorFlow op, to be run before the + beginning of any sweep. It marks all items as not-processed. + * axis_train_op: A Tensorflow op, to be run during this axis' sweeps. + """ + processed_items_init = array_ops.fill(dims=[num_items], value=False) + with ops.colocate_with(processed_items_init): + processed_items = variable_scope.variable( + processed_items_init, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False, + name="processed_" + axis_name) + reset_processed_items_op = state_ops.assign( + processed_items, processed_items_init, + name="reset_processed_" + axis_name) + _, update_op, loss, reg, sum_weights = update_fn(sp_input) + input_indices = sp_input.indices[:, 0] + with ops.control_dependencies([ + update_op, + state_ops.assign(loss_var, loss + reg), + state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]): + with ops.colocate_with(processed_items): + update_processed_items = state_ops.scatter_update( + processed_items, + input_indices, + array_ops.ones_like(input_indices, dtype=dtypes.bool), + name="update_processed_{}_indices".format(axis_name)) + with ops.control_dependencies([update_processed_items]): + is_sweep_done = math_ops.reduce_all(processed_items) + axis_train_op = control_flow_ops.group( + global_step_incr_op, + state_ops.assign(is_sweep_done_var, is_sweep_done), + state_ops.assign_add( + completed_sweeps_var, + math_ops.cast(is_sweep_done, dtypes.int32)), + name="{}_sweep_train_op".format(axis_name)) + return reset_processed_items_op, axis_train_op - sweep_hook = _SweepHook( - is_row_sweep_var, - [train_op, loss], - params["num_rows"], - params["num_cols"], - input_row_indices, - input_col_indices, - row_prep_ops, - col_prep_ops, - init_ops, - completed_sweeps_var) - training_hooks = [sweep_hook] - if max_sweeps is not None: - training_hooks.append(_StopAtSweepHook(max_sweeps)) + reset_processed_rows_op, row_train_op = create_axis_ops( + input_rows, + params["num_rows"], + lambda x: model.update_row_factors(sp_input=x, transpose_input=False), + "rows") + reset_processed_cols_op, col_train_op = create_axis_ops( + input_cols, + params["num_cols"], + lambda x: model.update_col_factors(sp_input=x, transpose_input=True), + "cols") + switch_op = control_flow_ops.group( + state_ops.assign( + is_row_sweep_var, math_ops.logical_not(is_row_sweep_var)), + reset_processed_rows_op, + reset_processed_cols_op, + name="sweep_switch_op") + row_prep_ops = [ + model.row_update_prep_gramian_op, model.initialize_row_update_op] + col_prep_ops = [ + model.col_update_prep_gramian_op, model.initialize_col_update_op] + init_op = model.worker_init + sweep_hook = _SweepHook( + is_row_sweep_var, is_sweep_done_var, init_op, + row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op) + training_hooks = [sweep_hook] + if max_sweeps is not None: + training_hooks.append(_StopAtSweepHook(max_sweeps)) - # The root weighted squared error = - # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) - summary.scalar("loss", loss) # the estimated total training loss - summary.scalar("root_weighted_squared_error", root_weighted_squared_error) - summary.scalar("completed_sweeps", completed_sweeps_var) + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.TRAIN, + predictions={}, + loss=loss_var, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=training_hooks) - # Prediction ops (only return predictions in INFER mode) - predictions = {} - if mode == model_fn.ModeKeys.INFER: - project_row = features[WALSMatrixFactorization.PROJECT_ROW] + # INFER mode + elif mode == model_fn.ModeKeys.INFER: projection_weights = features.get( WALSMatrixFactorization.PROJECTION_WEIGHTS) @@ -364,17 +325,45 @@ def _wals_factorization_model_function(features, labels, mode, params): projection_weights=projection_weights, transpose_input=True) - predictions[WALSMatrixFactorization.PROJECTION_RESULT] = ( - control_flow_ops.cond(project_row, get_row_projection, - get_col_projection)) + predictions = { + WALSMatrixFactorization.PROJECTION_RESULT: control_flow_ops.cond( + features[WALSMatrixFactorization.PROJECT_ROW], + get_row_projection, + get_col_projection) + } - return model_fn.ModelFnOps( - mode=mode, - predictions=predictions, - loss=loss, - eval_metric_ops={}, - train_op=train_op, - training_hooks=training_hooks) + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.INFER, + predictions=predictions, + loss=None, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=[]) + + # EVAL mode + elif mode == model_fn.ModeKeys.EVAL: + def get_row_loss(): + _, _, loss, reg, _ = model.update_row_factors( + sp_input=input_rows, transpose_input=False) + return loss + reg + def get_col_loss(): + _, _, loss, reg, _ = model.update_col_factors( + sp_input=input_cols, transpose_input=True) + return loss + reg + loss = control_flow_ops.cond( + features[WALSMatrixFactorization.PROJECT_ROW], + get_row_loss, + get_col_loss) + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.EVAL, + predictions={}, + loss=loss, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=[]) + + else: + raise ValueError("mode=%s is not recognized." % str(mode)) class WALSMatrixFactorization(estimator.Estimator): @@ -452,6 +441,10 @@ class WALSMatrixFactorization(estimator.Estimator): PROJECTION_RESULT = "projection" # Name of the completed_sweeps variable COMPLETED_SWEEPS = "completed_sweeps" + # Name of the loss variable + LOSS = "WALS_loss" + # Name of the Root Weighted Squared Error variable + RWSE = "WALS_RWSE" def __init__(self, num_rows, diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 8bd72b7025a..36b483c6d7a 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -417,73 +417,67 @@ class WALSMatrixFactorizationUnsupportedTest(test.TestCase): class SweepHookTest(test.TestCase): - def setUp(self): - self._num_rows = 5 - self._num_cols = 7 - self._train_op = control_flow_ops.no_op() - self._row_prep_done = variables.Variable(False) - self._col_prep_done = variables.Variable(False) - self._init_done = variables.Variable(False) - self._row_prep_ops = [state_ops.assign(self._row_prep_done, True)] - self._col_prep_ops = [state_ops.assign(self._col_prep_done, True)] - self._init_ops = [state_ops.assign(self._init_done, True)] - self._input_row_indices_ph = array_ops.placeholder(dtypes.int64) - self._input_col_indices_ph = array_ops.placeholder(dtypes.int64) - def test_sweeps(self): - def ind_feed(row_indices, col_indices): - return { - self._input_row_indices_ph: row_indices, - self._input_col_indices_ph: col_indices - } + is_row_sweep_var = variables.Variable(True) + is_sweep_done_var = variables.Variable(False) + init_done = variables.Variable(False) + row_prep_done = variables.Variable(False) + col_prep_done = variables.Variable(False) + row_train_done = variables.Variable(False) + col_train_done = variables.Variable(False) + + init_op = state_ops.assign(init_done, True) + row_prep_op = state_ops.assign(row_prep_done, True) + col_prep_op = state_ops.assign(col_prep_done, True) + row_train_op = state_ops.assign(row_train_done, True) + col_train_op = state_ops.assign(col_train_done, True) + train_op = control_flow_ops.no_op() + switch_op = control_flow_ops.group( + state_ops.assign(is_sweep_done_var, False), + state_ops.assign(is_row_sweep_var, + math_ops.logical_not(is_row_sweep_var))) + mark_sweep_done = state_ops.assign(is_sweep_done_var, True) with self.test_session() as sess: - is_row_sweep_var = variables.Variable(True) - completed_sweeps_var = variables.Variable(0) sweep_hook = wals_lib._SweepHook( is_row_sweep_var, - [self._train_op], - self._num_rows, - self._num_cols, - self._input_row_indices_ph, - self._input_col_indices_ph, - self._row_prep_ops, - self._col_prep_ops, - self._init_ops, - completed_sweeps_var) + is_sweep_done_var, + init_op, + [row_prep_op], + [col_prep_op], + row_train_op, + col_train_op, + switch_op) mon_sess = monitored_session._HookedSession(sess, [sweep_hook]) sess.run([variables.global_variables_initializer()]) - # Init ops should run before the first run. Row sweep not completed. - mon_sess.run(self._train_op, ind_feed([0, 1, 2], [])) - self.assertTrue(sess.run(self._init_done), - msg='init ops not run by the sweep_hook') - self.assertTrue(sess.run(self._row_prep_done), - msg='row_prep not run by the sweep_hook') - self.assertTrue(sess.run(is_row_sweep_var), - msg='Row sweep is not complete but is_row_sweep is ' - 'False.') - # Row sweep completed. - mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6])) - self.assertTrue(sess.run(completed_sweeps_var) == 1, - msg='Completed sweeps should be equal to 1.') - self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is complete but is_sweep_done is False.') - # Col init ops should run. Col sweep not completed. - mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4])) - self.assertTrue(sess.run(self._col_prep_done), - msg='col_prep not run by the sweep_hook') - self.assertFalse(sess.run(is_row_sweep_var), - msg='Col sweep is not complete but is_row_sweep is ' - 'True.') - self.assertFalse(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is not complete but is_sweep_done is True.') - # Col sweep completed. - mon_sess.run(self._train_op, ind_feed([], [4, 5, 6])) - self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is complete but is_sweep_done is False.') - self.assertTrue(sess.run(completed_sweeps_var) == 2, - msg='Completed sweeps should be equal to 2.') + # Row sweep. + mon_sess.run(train_op) + self.assertTrue(sess.run(init_done), + msg='init op not run by the Sweephook') + self.assertTrue(sess.run(row_prep_done), + msg='row_prep_op not run by the SweepHook') + self.assertTrue(sess.run(row_train_done), + msg='row_train_op not run by the SweepHook') + self.assertTrue( + sess.run(is_row_sweep_var), + msg='Row sweep is not complete but is_row_sweep_var is False.') + # Col sweep. + mon_sess.run(mark_sweep_done) + mon_sess.run(train_op) + self.assertTrue(sess.run(col_prep_done), + msg='col_prep_op not run by the SweepHook') + self.assertTrue(sess.run(col_train_done), + msg='col_train_op not run by the SweepHook') + self.assertFalse( + sess.run(is_row_sweep_var), + msg='Col sweep is not complete but is_row_sweep_var is True.') + # Row sweep. + mon_sess.run(mark_sweep_done) + mon_sess.run(train_op) + self.assertTrue( + sess.run(is_row_sweep_var), + msg='Col sweep is complete but is_row_sweep_var is False.') class StopAtSweepHookTest(test.TestCase): diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index 011ddeaa9a1..faedee6f877 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -224,7 +224,8 @@ def transform(images, transforms, interpolation="NEAREST", name=None): `(x, y)` to a transformed *input* point `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to - the transform mapping input points to output points. + the transform mapping input points to output points. Note that gradients + are not backpropagated into transformation parameters. interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". Returns: diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 60c245166d6..7d65ac9a43d 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -68,6 +68,7 @@ py_test( srcs = ["layer_collection_test.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/contrib/kfac/python/ops:layer_collection", "//tensorflow/python:array_ops", @@ -75,6 +76,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:random_seed", "//tensorflow/python:variable_scope", diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py index 524e8338fde..c5ad90d1dc7 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.kfac.python.ops import fisher_blocks from tensorflow.contrib.kfac.python.ops import fisher_factors from tensorflow.contrib.kfac.python.ops import layer_collection from tensorflow.python.framework import dtypes @@ -25,6 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -105,8 +107,10 @@ class LayerCollectionTest(test.TestCase): array_ops.constant(4), [1, 1, 1, 1], 'SAME', array_ops.ones((1, 1, 1, 1)), array_ops.constant(3)) lc.register_conv2d( - array_ops.constant(4), [1, 1, 1, 1], 'SAME', - array_ops.ones((1, 1, 1, 1)), array_ops.constant(3), + array_ops.constant(4), [1, 1, 1, 1], + 'SAME', + array_ops.ones((1, 1, 1, 1)), + array_ops.constant(3), approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_generic( array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) @@ -122,8 +126,8 @@ class LayerCollectionTest(test.TestCase): random_seed.set_random_seed(200) lc = layer_collection.LayerCollection() key = array_ops.constant(1) - lc.register_fully_connected(key, - array_ops.constant(2), array_ops.constant(3)) + lc.register_fully_connected(key, array_ops.constant(2), + array_ops.constant(3)) with self.assertRaises(ValueError): lc.register_generic(key, 16) @@ -191,8 +195,8 @@ class LayerCollectionTest(test.TestCase): lc.register_block((x, y), MockFisherBlock('foo')) self.assertEqual( - set([MockFisherBlock('2'), MockFisherBlock('foo')]), - set(lc.get_blocks())) + set([MockFisherBlock('2'), MockFisherBlock('foo')]), set( + lc.get_blocks())) def testRegisterTupleVarSomeRegisteredInOtherTuples(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -464,6 +468,66 @@ class LayerCollectionTest(test.TestCase): use_count_map = lc.get_use_count_map() self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map) + def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self): + x = variable_scope.get_variable('x', shape=()) + y = variable_scope.get_variable('y', shape=()) + z = variable_scope.get_variable('z', shape=()) + lc = layer_collection.LayerCollection() + lc.define_linked_parameters((x, y)) + + with self.assertRaises(ValueError): + lc.define_linked_parameters((x, z)) + + def testIdentifySubsetPreviouslyRegisteredTensor(self): + x = variable_scope.get_variable('x', shape=()) + y = variable_scope.get_variable('y', shape=()) + lc = layer_collection.LayerCollection() + lc.define_linked_parameters((x, y)) + + with self.assertRaises(ValueError): + lc.define_linked_parameters(x) + + def testSpecifyApproximation(self): + w_0 = variable_scope.get_variable('w_0', [10, 10]) + w_1 = variable_scope.get_variable('w_1', [10, 10]) + + b_0 = variable_scope.get_variable('b_0', [10]) + b_1 = variable_scope.get_variable('b_1', [10]) + + x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) + x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) + + pre_bias_0 = math_ops.matmul(x_0, w_0) + pre_bias_1 = math_ops.matmul(x_1, w_1) + + # Build the fully connected layers in the graph. + pre_bias_0 + b_0 # pylint: disable=pointless-statement + pre_bias_1 + b_1 # pylint: disable=pointless-statement + + lc = layer_collection.LayerCollection() + lc.define_linked_parameters( + w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME) + lc.define_linked_parameters( + w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME) + lc.define_linked_parameters( + b_0, approximation=layer_collection.APPROX_FULL_NAME) + lc.define_linked_parameters( + b_1, approximation=layer_collection.APPROX_FULL_NAME) + + lc.register_fully_connected(w_0, x_0, pre_bias_0) + lc.register_fully_connected( + w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME) + self.assertIsInstance(lc.fisher_blocks[w_0], + fisher_blocks.FullyConnectedDiagonalFB) + self.assertIsInstance(lc.fisher_blocks[w_1], + fisher_blocks.FullyConnectedKFACBasicFB) + + lc.register_generic(b_0, batch_size=1) + lc.register_generic( + b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME) + self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB) + self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 7300a7998c2..2139a261e05 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -38,12 +38,26 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest - # Names for various approximations that can be requested for Fisher blocks. APPROX_KRONECKER_NAME = "kron" APPROX_DIAGONAL_NAME = "diagonal" APPROX_FULL_NAME = "full" +_GENERIC_APPROX_TO_BLOCK_TYPES = { + APPROX_FULL_NAME: fb.FullFB, + APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, +} + +_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, + APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, +} + +_CONV2D_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, + APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, +} + # Possible value for 'reuse' keyword argument. Sets 'reuse' to # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" @@ -51,6 +65,14 @@ VARIABLE_SCOPE = "VARIABLE_SCOPE" # TODO(jamesmartens): need to add find_canonical_output back into this somewhere +def ensure_sequence(obj): + """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" + if isinstance(obj, (tuple, list)): + return obj + else: + return (obj,) + + class LayerParametersDict(OrderedDict): """An OrderedDict where keys are Tensors or tuples of Tensors. @@ -110,9 +132,14 @@ class LayerCollection(object): def __init__(self, graph=None, name="LayerCollection"): self.fisher_blocks = LayerParametersDict() self.fisher_factors = OrderedDict() + self._linked_parameters = dict( + ) # dict mapping sets of variables to optionally specified approximations. self._graph = graph or ops.get_default_graph() self._loss_dict = {} # {str: LossFunction} self._subgraph = None + self._default_generic_approximation = APPROX_FULL_NAME + self._default_fully_connected_approximation = APPROX_KRONECKER_NAME + self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME with variable_scope.variable_scope(None, default_name=name) as scope: self._var_scope = scope.name @@ -122,6 +149,70 @@ class LayerCollection(object): """LossFunctions registered with this LayerCollection.""" return list(self._loss_dict.values()) + def is_variable_registered(self, variable): + """Checks whether the variable has already been registered. + + Args: + variable: A single variable or tensor. + Returns: + True if the variable has been registered either by itself or as part of a + tuple. + """ + return any([ + variable in key if isinstance(key, (tuple, list)) else variable == key + for key in self.fisher_blocks.keys() + ]) + + @property + def linked_parameters(self): + """Groups of parameters with an optionally specified approximation. + + Linked parameters can be added using `define_linked_parameters`. + If an approximation is specified, then this approximation will be used + when registering a layer with exactly these parameters, unless an + approximation is specified when calling the registration function. + + Returns: + A `dict` mapping tuples of parameters to an optional string. + """ + return self._linked_parameters + + @property + def default_generic_approximation(self): + return self._default_generic_approximation + + @default_generic_approximation.setter + def default_generic_approximation(self, value): + if value not in _GENERIC_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for generic variables.".format( + value)) + self._default_generic_approximation = value + + @property + def default_fully_connected_approximation(self): + return self._default_fully_connected_approximation + + @default_fully_connected_approximation.setter + def default_fully_connected_approximation(self, value): + if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for fully connected layers.".format( + value)) + self._default_fully_connected_approximation = value + + @property + def default_conv2d_approximation(self): + return self._default_convolution_2d_approximation + + @default_conv2d_approximation.setter + def default_conv2d_approximation(self, value): + if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for 2d convolutional layers.".format( + value)) + self._default_convolution_2d_approximation = value + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): """Validates and registers the layer_key associated with the fisher_block. @@ -187,7 +278,8 @@ class LayerCollection(object): # Find all keys that are either supersets or subsets of 'layer_key'. inclusions = { fisher_elt - for layer_elt in layer_key for fisher_elt in self.fisher_blocks + for layer_elt in layer_key + for fisher_elt in self.fisher_blocks if self._equal_or_subset(layer_elt, fisher_elt) } @@ -294,6 +386,49 @@ class LayerCollection(object): def subgraph(self): return self._subgraph + def define_linked_parameters(self, params, approximation=None): + """Identify a set of parameters that should be grouped together. + + During automatic graph scanning, any matches containing variables that have + been identified as part of a linked group will be filtered out unless + the match parameters are exactly equal to the ones specified in the linked + group. + + Args: + params: A variable, or a tuple or list of variables. The variables + to be linked. + approximation: Optional string specifying the type of approximation to use + for these variables. If unspecified, this layer collection's default + approximation for the layer type will be used. + + Raises: + ValueError: If the parameters were already registered in a layer or + identified as part of an incompatible group. + """ + params = frozenset(ensure_sequence(params)) + + # Check if any of the variables in 'params' is already in + # 'self.fisher_blocks.keys()'. + for registered_params, fisher_block in self.fisher_blocks.items(): + registered_params_set = set(ensure_sequence(registered_params)) + for variable in params: + if (variable in registered_params_set and + params != registered_params_set): + raise ValueError( + "Can't link parameters {}, variable {} was already registered in " + "group {} with layer {}".format(params, variable, + registered_params, fisher_block)) + + # Check if any of the variables in 'params' is already in + # 'self.linked_parameters'. + for variable in params: + for other_linked_params in self.linked_parameters: + if variable in other_linked_params: + raise ValueError("Can't link parameters {}, variable {} was already " + "linked in group {}.".format(params, variable, + other_linked_params)) + self._linked_parameters[params] = approximation + def create_subgraph(self): if not self.losses: raise ValueError("Must have at least one registered loss.") @@ -307,11 +442,19 @@ class LayerCollection(object): return math_ops.add_n( tuple(loss.evaluate_on_sample() for loss in self.losses)) + def _get_linked_approx(self, params): + """If params were linked, return their specified approximation.""" + params_set = frozenset(ensure_sequence(params)) + if params_set in self.linked_parameters: + return self.linked_parameters[params_set] + else: + return None + def register_fully_connected(self, params, inputs, outputs, - approx=APPROX_KRONECKER_NAME, + approx=None, reuse=VARIABLE_SCOPE): """Registers a fully connnected layer. @@ -332,15 +475,15 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - approx_to_block_types = { - APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, - APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, - } + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_fully_connected_approximation - if approx not in approx_to_block_types: + if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: raise ValueError("Bad value {} for approx.".format(approx)) - block_type = approx_to_block_types[approx] + block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx] has_bias = isinstance(params, (tuple, list)) block = self.register_block(params, block_type(self, has_bias), reuse=reuse) @@ -352,7 +495,7 @@ class LayerCollection(object): padding, inputs, outputs, - approx=APPROX_KRONECKER_NAME, + approx=None, reuse=VARIABLE_SCOPE): """Registers a convolutional layer. @@ -377,15 +520,16 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - approx_to_block_types = { - APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, - APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, - } - if approx not in approx_to_block_types: + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_conv2d_approximation + + if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES: raise ValueError("Bad value {} for approx.".format(approx)) - block_type = approx_to_block_types[approx] + block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block( params, block_type(self, params, strides, padding), reuse=reuse) block.register_additional_minibatch(inputs, outputs) @@ -393,7 +537,7 @@ class LayerCollection(object): def register_generic(self, params, batch_size, - approx=APPROX_DIAGONAL_NAME, + approx=None, reuse=VARIABLE_SCOPE): """Registers a generic layer. @@ -413,15 +557,16 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - approx_to_block_types = { - APPROX_FULL_NAME: fb.FullFB, - APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, - } - if approx not in approx_to_block_types: + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_generic_approximation + + if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES: raise ValueError("Bad value {} for approx.".format(approx)) - block_type = approx_to_block_types[approx] + block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block(params, block_type(self, params), reuse=reuse) block.register_additional_minibatch(batch_size) @@ -560,10 +705,10 @@ class LayerCollection(object): try: hash(args) except TypeError: - raise TypeError(( - "Unable to use (cls, args) = ({}, {}) as a key in " - "LayerCollection.fisher_factors. The pair cannot be hashed." - ).format(cls, args)) + raise TypeError( + ("Unable to use (cls, args) = ({}, {}) as a key in " + "LayerCollection.fisher_factors. The pair cannot be hashed.").format( + cls, args)) with variable_scope.variable_scope(self._var_scope): return utils.setdefault(self.fisher_factors, (cls, args), diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index bfa15e0948c..88299e495cb 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -44,7 +44,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): momentum=0., momentum_type="regular", norm_constraint=None, - name="KFAC",): + name="KFAC", + estimation_mode="gradients"): """Initializes the KFAC optimizer with the given settings. Args: @@ -72,6 +73,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): specified value. May only be used with momentum type 'regular'. (Default: None) name: The name for this optimizer. (Default: 'KFAC') + estimation_mode: The type of estimator to use for the Fishers. Can be + 'gradients', 'empirical', 'curvature_propagation', or 'exact'. + (Default: 'gradients'). See the doc-string for FisherEstimator for + more a more detailed description of these options. Raises: ValueError: If the momentum type is unsupported. @@ -86,7 +91,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): variables = tf_variables.trainable_variables() self._fisher_est = est.FisherEstimator(variables, cov_ema_decay, damping, - layer_collection) + layer_collection, + estimation_mode=estimation_mode) momentum_type = momentum_type.lower() legal_momentum_types = ["regular", "adam", "qmodel"] diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 160d9eb3034..30630852181 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -1403,7 +1403,8 @@ def dropout(inputs, noise_shape=None, is_training=True, outputs_collections=None, - scope=None): + scope=None, + seed=None): """Returns a dropout op applied to the input. With probability `keep_prob`, outputs the input element scaled up by @@ -1421,6 +1422,8 @@ def dropout(inputs, Otherwise, inputs is returned. outputs_collections: Collection to add the outputs. scope: Optional scope for name_scope. + seed: A Python integer. Used to create random seeds. See + @{tf.set_random_seed} for behavior. Returns: A tensor representing the output of the operation. @@ -1430,6 +1433,7 @@ def dropout(inputs, inputs = ops.convert_to_tensor(inputs) layer = core_layers.Dropout(rate=1 - keep_prob, noise_shape=noise_shape, + seed=seed, name=sc.name, _scope=sc) outputs = layer.apply(inputs, training=is_training) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index f2406205f38..9019d3a6099 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1345,11 +1345,20 @@ class DropoutTest(test.TestCase): num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) output = _layers.dropout(images) num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) - sess.run(variables_lib.global_variables_initializer()) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertLess(num_elem, num_elem_initial / 2 + 0.1) self.assertGreater(num_elem, num_elem_initial / 2 - 0.1) + def testDropoutSeed(self): + """Test that providing the same seed produces the same result.""" + height, width = 10, 10 + with self.test_session() as sess: + images = random_ops.random_uniform( + (5, height, width, 3), seed=1, name='images') + output1 = _layers.dropout(images, seed=1) + output2 = _layers.dropout(images, seed=1) + self.assertAllEqual(*sess.run([output1, output2])) + def testCreateDropoutNoTraining(self): height, width = 3, 3 with self.test_session() as sess: @@ -1358,7 +1367,6 @@ class DropoutTest(test.TestCase): num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) output = _layers.dropout(images, is_training=False) num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) - sess.run(variables_lib.global_variables_initializer()) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertEqual(num_elem, num_elem_initial) outputs, inputs = sess.run([output, images]) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 49413092a6b..6ffd2a13399 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -33,6 +33,7 @@ from __future__ import division from __future__ import print_function import os +import tempfile import time from tensorflow.contrib.layers.python.layers import feature_column @@ -644,18 +645,22 @@ def make_best_model_export_strategy(serving_input_fn, # TODO(b/67013778): Revisit this approach when corresponding changes to # TF Core are finalized. -def extend_export_strategy(base_export_strategy, post_export_fn, - post_export_name): +def extend_export_strategy(base_export_strategy, + post_export_fn, + post_export_name=None): """Extend ExportStrategy, calling post_export_fn after export. Args: base_export_strategy: An ExportStrategy that can be passed to the Experiment constructor. post_export_fn: A user-specified function to call after exporting the - SavedModel. Takes the export directory as an argument, and returns - a string path to a (potentially different) SavedModel. + SavedModel. Takes two arguments - the path to the SavedModel exported by + base_export_strategy and the directory where to export the SavedModel + modified by the post_export_fn. Returns the path to the exported + SavedModel. post_export_name: The directory name under the export base directory where - SavedModels generated by the post_export_fn will be written. + SavedModels generated by the post_export_fn will be written. If None, the + directory name of base_export_strategy is used. Returns: An ExportStrategy that can be passed to the Experiment constructor. @@ -675,12 +680,24 @@ def extend_export_strategy(base_export_strategy, post_export_fn, Raises: ValueError: If `estimator` is a ${tf.estimator.Estimator} instance - and `default_output_alternative_key` was specified. + and `default_output_alternative_key` was specified or if post_export_fn + does not return a valid directory. """ - export_dir = base_export_strategy.export(estimator, export_dir_base, - checkpoint_path) - if post_export_fn: - export_dir = post_export_fn(export_dir) - return export_dir + tmp_base_export_dir = tempfile.mkdtemp() + tmp_base_export = base_export_strategy.export( + estimator, tmp_base_export_dir, checkpoint_path) + tmp_post_export_dir = tempfile.mkdtemp() + tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir) - return export_strategy.ExportStrategy(post_export_name, export_fn) + if not tmp_post_export.startswith(tmp_post_export_dir): + raise ValueError('post_export_fn must return a sub-directory of {}' + .format(tmp_post_export_dir)) + export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir) + + gfile.Rename( + os.path.join(tmp_post_export_dir, export_relpath), + os.path.join(export_dir_base, export_relpath)) + return os.path.join(export_dir_base, export_relpath) + + name = post_export_name if post_export_name else base_export_strategy.name + return export_strategy.ExportStrategy(name, export_fn) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index 27f17b54221..ec3a88003f0 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -743,12 +743,19 @@ class SavedModelExportUtilsTest(test.TestCase): None) def test_extend_export_strategy(self): - def _base_export_fn(unused_estimator, export_dir_base, - unused_checkpoint_path=None): - return export_dir_base + "/e1" - def _post_export_fn(orig_path): - return orig_path + "/rewrite" + def _base_export_fn(unused_estimator, + export_dir_base, + unused_checkpoint_path=None): + base_path = os.path.join(export_dir_base, "e1") + gfile.MkDir(base_path) + return base_path + + def _post_export_fn(orig_path, new_path): + assert orig_path.endswith("/e1") + post_export_path = os.path.join(new_path, "rewrite") + gfile.MkDir(post_export_path) + return post_export_path base_export_strategy = export_strategy_lib.ExportStrategy( "Servo", _base_export_fn) @@ -758,9 +765,67 @@ class SavedModelExportUtilsTest(test.TestCase): self.assertEqual(final_export_strategy.name, "Servo2") test_estimator = TestEstimator() - final_path = final_export_strategy.export(test_estimator, "/path/to/orig", - "/path/to/checkpoint") - self.assertEqual("/path/to/orig/e1/rewrite", final_path) + tmpdir = tempfile.mkdtemp() + final_path = final_export_strategy.export(test_estimator, tmpdir, + os.path.join( + tmpdir, "checkpoint")) + self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + + def test_extend_export_strategy_same_name(self): + + def _base_export_fn(unused_estimator, + export_dir_base, + unused_checkpoint_path=None): + base_path = os.path.join(export_dir_base, "e1") + gfile.MkDir(base_path) + return base_path + + def _post_export_fn(orig_path, new_path): + assert orig_path.endswith("/e1") + post_export_path = os.path.join(new_path, "rewrite") + gfile.MkDir(post_export_path) + return post_export_path + + base_export_strategy = export_strategy_lib.ExportStrategy( + "Servo", _base_export_fn) + + final_export_strategy = saved_model_export_utils.extend_export_strategy( + base_export_strategy, _post_export_fn) + self.assertEqual(final_export_strategy.name, "Servo") + + test_estimator = TestEstimator() + tmpdir = tempfile.mkdtemp() + final_path = final_export_strategy.export(test_estimator, tmpdir, + os.path.join( + tmpdir, "checkpoint")) + self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + + def test_extend_export_strategy_raises_error(self): + + def _base_export_fn(unused_estimator, + export_dir_base, + unused_checkpoint_path=None): + base_path = os.path.join(export_dir_base, "e1") + gfile.MkDir(base_path) + return base_path + + def _post_export_fn(unused_orig_path, unused_new_path): + return tempfile.mkdtemp() + + base_export_strategy = export_strategy_lib.ExportStrategy( + "Servo", _base_export_fn) + + final_export_strategy = saved_model_export_utils.extend_export_strategy( + base_export_strategy, _post_export_fn) + + test_estimator = TestEstimator() + tmpdir = tempfile.mkdtemp() + with self.assertRaises(ValueError) as ve: + final_export_strategy.export(test_estimator, tmpdir, + os.path.join(tmpdir, "checkpoint")) + + self.assertTrue( + "post_export_fn must return a sub-directory" in str(ve.exception)) def _create_test_export_dir(export_dir_base): diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md index feb35c850eb..827c5d0baa9 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/contrib/lite/README.md @@ -48,9 +48,8 @@ NOTE: Bazel does not currently support building for Android on Windows. Full sup ### Install Android NDK and SDK Bazel is the primary build system for TensorFlow. Bazel and the Android NDK and SDK must be installed on your system. - Install the latest version of Bazel as per the instructions on the [Bazel website](https://bazel.build/versions/master/docs/install.html) - - The Android NDK is required to build the native (C/C++) TensorFlow code. The current recommended version is 14b, which may be found [here](https://developer.android.com/tools/revisions/build-tools.html). - - The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android Studio](https://developer.android.com/studio/index.html). Build tools API >= 23 is required to build the TensorFlow Android demo (though it will run on API >= 21 devices). - + - The Android NDK is required to build the native (C/C++) TensorFlow Lite code. The current recommended version is 14b, which can be found [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads). + - The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android Studio](https://developer.android.com/studio/index.html). Build tools API >= 23 is required to build the TF Android demo (though it will run on API >= 21 devices). - In the root of the TensorFlow repository update the `WORKSPACE` file with the `api_level` and location of the SDK and NDK. If you installed it with AndroidStudio the SDK path can be found in the SDK manager, and the default NDK path is:`{SDK path}/ndk-bundle.` ``` @@ -147,7 +146,7 @@ bazel-bin/tensorflow/python/tools/freeze_graph\ ``` The user has to first build the freeze_graph script using bazel and then run the script. The input_binary flag has to be enabled to ensure that the protobuf is read and written in binary format. The user has to input the .pb and the .ckpt files to freeze the graph The output_node_names may not be obvious outside of the code that built the model. The easiest way to find them is to visualize the graph, either with -graphviz, or in tensorboard. +graphviz, or [in tensorboard](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/#3). This frozen Graphdef is now ready to be converted to flatbuffer format (.lite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. @@ -166,11 +165,11 @@ bazel run --config=opt tensorflow/contrib/lite/toco:toco -- \ - The input_file argument should point to the frozen GraphDef file that holds the model architecture. - The output_file argument should point to where the TensorFlow Lite model file should be generated. -- The input_type and inference_type arguments should be set to FLOAT, unless converting a quantized model. -- Setting the input_array, output_array and input_shape arguments are a bit trickier. The easiest way to find these values is to explore the graph in TensorBoard. The user should reuse the arguments that were used for specifying the output nodes for inference in the `freeze_graph`step. +- The input_type and inference_type arguments should be set to FLOAT, unless converted a [quantized](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/) model. +- Setting the input_array, output_array and input_shape arguments are a bit trickier. The easiest way to find these values is to explore the graph in tensorboard . The user should reuse the arguments that were used for specifying the output nodes for inference in the `freeze_graph`step. Note, it is also possible to use the Tensorflow Optimizing Converter through protos either from Python or from the command line see the -documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/README.md). A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, +documentation [here](https://github.com/tensorflow/tensorflow/tree/mastertensorflow/contrib/lite/python:toco_from_protos target) A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, ``` import tensorflow as tf @@ -199,4 +198,4 @@ The [demo app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/c Note that you’d need to follow instructions for installing TensorFlow on Android, setting up bazel and Android Studio outlined [here](https://www.tensorflow.org/mobile/android_build). ### For iOS -Follow the documentation [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md) to integrate a TFLite model into your app. +Follow the documentation [here](https://github.com/TensorFlow/TensorFlow/blob/master/TensorFlow/contrib/lite/g3doc/ios.md) to get integrate a TFLite model into your app. diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index dd414d53bd3..72f705fe424 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -183,8 +183,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // Reduction sum - // TODO(ghodrat): Consider not reusing state for the temporary output, this - // way ReductionSum operates on row-vector instead of column vector. for (int b = 0; b < batch_size; b++) { float* output_ptr_batch = output->data.f + b * num_units; float* scratch_ptr_batch = scratch->data.f + b * num_filters; diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index f8208f6f98c..e2f3560e61b 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -30,6 +30,17 @@ limitations under the License. namespace tflite { +namespace { +inline const tflite::Model* VerifyAndGetModel(const void* buf, size_t len) { + ::flatbuffers::Verifier verifier(static_cast(buf), len); + if (VerifyModelBuffer(verifier)) { + return ::tflite::GetModel(buf); + } else { + return nullptr; + } +} +} // namespace + const char* kEmptyTensorName = ""; std::unique_ptr FlatBufferModel::BuildFromFile( @@ -64,7 +75,7 @@ FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, if (!allocation_->valid()) return; if (!CheckModelIdentifier()) return; - model_ = ::tflite::GetModel(allocation_->base()); + model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); } bool FlatBufferModel::CheckModelIdentifier() const { @@ -84,7 +95,8 @@ FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes, : DefaultErrorReporter()) { allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter); if (!allocation_->valid()) return; - model_ = ::tflite::GetModel(allocation_->base()); + + model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); } FlatBufferModel::~FlatBufferModel() { delete allocation_; } diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index ae823650d6d..61043866420 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/contrib/lite/model.h" @@ -245,6 +246,14 @@ TEST(BasicFlatBufferModel, TestNullErrorReporter) { ASSERT_NE(interpreter->Invoke(), kTfLiteOk); } +// Test what happens if we cannot bind any of the ops. +TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) { + std::string corrupted_data = "123"; + auto model = FlatBufferModel::BuildFromBuffer(corrupted_data.c_str(), + corrupted_data.length()); + ASSERT_FALSE(model); +} + // TODO(aselle): Add tests for serialization of builtin op data types. // These tests will occur with the evaluation tests of individual operators, // not here. diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/README.md b/tensorflow/contrib/lite/models/testdata/g3doc/README.md index d0cdae6bdfc..77fe8b3f84f 100644 --- a/tensorflow/contrib/lite/models/testdata/g3doc/README.md +++ b/tensorflow/contrib/lite/models/testdata/g3doc/README.md @@ -91,3 +91,10 @@ same input. [TTS model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_tts_model_test.cc) [ASR AM model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc) + +## Android Support +The models have been tested on Android phones, using the following tests: + +[Hotword] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/android/BUILD?rcl=172930882&l=25) + +[Speaker-id] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/android/BUILD?rcl=172930882&l=36) diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 77d381c1c5d..eb08b5d1e54 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -75,7 +75,6 @@ cc_library( ":runtime", ":toco_port", "//tensorflow/core:lib", - "@protobuf_archive//:protobuf_headers", ], ) @@ -88,9 +87,6 @@ cc_library( "toco_graphviz_dump_options.h", ], visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/strings", - ], ) cc_library( diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index d44b5dc7b02..9cb26c8752c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -143,7 +143,7 @@ bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) { return true; } -bool HardcodeMinMaxForReshape(Model* model, Operator* op) { +bool HardcodeMinMaxForReshapeOrSqueeze(Model* model, Operator* op) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.minmax) { return false; @@ -201,8 +201,9 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForAverageOrMaxPool(model, op); break; + case OperatorType::kSqueeze: case OperatorType::kTensorFlowReshape: - changed = HardcodeMinMaxForReshape(model, op); + changed = HardcodeMinMaxForReshapeOrSqueeze(model, op); break; case OperatorType::kLogistic: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 5551755ea7f..d33597d3814 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -42,6 +42,7 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kL2Normalization || type == OperatorType::kAdd || type == OperatorType::kAveragePool || type == OperatorType::kMaxPool || type == OperatorType::kLogistic || type == OperatorType::kSoftmax || + type == OperatorType::kSqueeze || type == OperatorType::kTensorFlowReshape || type == OperatorType::kMul || type == OperatorType::kSpaceToDepth || type == OperatorType::kDepthToSpace; diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 24692ff12fb..6e2190cb7af 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -60,61 +60,6 @@ def _safe_div(numerator, denominator, name): name=name) -# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. -def _assert_weights_rank(weights, values): - """`weights` rank must be either `0`, or the same as 'values'.""" - return check_ops.assert_rank_in(weights, (0, array_ops.rank(values))) - - -def _count_condition(values, - weights=None, - metrics_collections=None, - updates_collections=None): - """Sums the weights of cases where the given values are True. - - If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. - - Args: - values: A `bool` `Tensor` of arbitrary size. - weights: Optional `Tensor` whose rank is either 0, or the same rank as - `values`, and must be broadcastable to `values` (i.e., all dimensions - must be either `1`, or the same as the corresponding `values` - dimension). - metrics_collections: An optional list of collections that the metric - value variable should be added to. - updates_collections: An optional list of collections that the metric update - ops should be added to. - - Returns: - value_tensor: A `Tensor` representing the current value of the metric. - update_op: An operation that accumulates the error from a batch of data. - - Raises: - ValueError: If `weights` is not `None` and its shape doesn't match `values`, - or if either `metrics_collections` or `updates_collections` are not a list - or tuple. - """ - check_ops.assert_type(values, dtypes.bool) - count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') - - values = math_ops.to_float(values) - if weights is not None: - weights = math_ops.to_float(weights) - with ops.control_dependencies((_assert_weights_rank(weights, values),)): - values = math_ops.multiply(values, weights) - - value_tensor = array_ops.identity(count_) - update_op = state_ops.assign_add(count_, math_ops.reduce_sum(values)) - - if metrics_collections: - ops.add_to_collections(metrics_collections, value_tensor) - - if updates_collections: - ops.add_to_collections(updates_collections, update_op) - - return value_tensor, update_op - - def streaming_true_positives(predictions, labels, weights=None, @@ -194,17 +139,13 @@ def streaming_true_negatives(predictions, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope(name, 'true_negatives', - (predictions, labels, weights)): - - predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access - predictions=math_ops.cast(predictions, dtype=dtypes.bool), - labels=math_ops.cast(labels, dtype=dtypes.bool), - weights=weights) - is_true_negative = math_ops.logical_and( - math_ops.equal(labels, False), math_ops.equal(predictions, False)) - return _count_condition(is_true_negative, weights, metrics_collections, - updates_collections) + return metrics.true_negatives( + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) def streaming_false_positives(predictions, @@ -294,34 +235,6 @@ def streaming_false_negatives(predictions, name=name) -# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. -def _broadcast_weights(weights, values): - """Broadcast `weights` to the same shape as `values`. - - This returns a version of `weights` following the same broadcast rules as - `mul(weights, values)`. When computing a weighted average, use this function - to broadcast `weights` before summing them; e.g., - `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`. - - Args: - weights: `Tensor` whose rank is either 0, or the same rank as `values`, and - must be broadcastable to `values` (i.e., all dimensions must be either - `1`, or the same as the corresponding `values` dimension). - values: `Tensor` of any shape. - - Returns: - `weights` broadcast to `values` shape. - """ - with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope: - weights_shape = weights.get_shape() - values_shape = values.get_shape() - if (weights_shape.is_fully_defined() and values_shape.is_fully_defined() and - weights_shape.is_compatible_with(values_shape)): - return weights - with ops.control_dependencies((_assert_weights_rank(weights, values),)): - return math_ops.multiply(weights, array_ops.ones_like(values), name=scope) - - def streaming_mean(values, weights=None, metrics_collections=None, @@ -423,8 +336,10 @@ def streaming_mean_tensor(values, updates_collections=updates_collections, name=name) -@deprecated(None, "Please switch to tf.metrics.accuracy. Note that the order " - "of the inputs of labels and predictions have been switched.") + +@deprecated( + None, 'Please switch to tf.metrics.accuracy. Note that the order of the ' + 'labels and predictions arguments has been switched.') def streaming_accuracy(predictions, labels, weights=None, @@ -592,53 +507,6 @@ def streaming_recall(predictions, name=name) -def _true_negatives(labels, - predictions, - weights=None, - metrics_collections=None, - updates_collections=None, - name=None): - """Sum the weights of true negatives. - - If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. - - Args: - labels: The ground truth values, a `Tensor` whose dimensions must match - `predictions`. Will be cast to `bool`. - predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will - be cast to `bool`. - weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions must - be either `1`, or the same as the corresponding `labels` dimension). - metrics_collections: An optional list of collections that the metric - value variable should be added to. - updates_collections: An optional list of collections that the metric update - ops should be added to. - name: An optional variable_scope name. - - Returns: - value_tensor: A `Tensor` representing the current value of the metric. - update_op: An operation that accumulates the error from a batch of data. - - Raises: - ValueError: If `predictions` and `labels` have mismatched shapes, or if - `weights` is not `None` and its shape doesn't match `predictions`, or if - either `metrics_collections` or `updates_collections` are not a list or - tuple. - """ - with variable_scope.variable_scope(name, 'true_negatives', - (predictions, labels, weights)): - - predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access - predictions=math_ops.cast(predictions, dtype=dtypes.bool), - labels=math_ops.cast(labels, dtype=dtypes.bool), - weights=weights) - is_true_negative = math_ops.logical_and( - math_ops.equal(labels, False), math_ops.equal(predictions, False)) - return _count_condition(is_true_negative, weights, metrics_collections, - updates_collections) - - def streaming_false_positive_rate(predictions, labels, weights=None, @@ -696,16 +564,16 @@ def streaming_false_positive_rate(predictions, weights=weights) false_p, false_positives_update_op = metrics.false_positives( - labels, - predictions, - weights, + labels=labels, + predictions=predictions, + weights=weights, metrics_collections=None, updates_collections=None, name=None) - true_n, true_negatives_update_op = _true_negatives( - labels, - predictions, - weights, + true_n, true_negatives_update_op = metrics.true_negatives( + labels=labels, + predictions=predictions, + weights=weights, metrics_collections=None, updates_collections=None, name=None) @@ -1102,8 +970,10 @@ def streaming_curve_points(labels=None, return points, update_op -@deprecated(None, "Please switch to tf.metrics.auc. Note that the order of " - "the inputs of labels and predictions have been switched.") + +@deprecated( + None, 'Please switch to tf.metrics.auc. Note that the order of the ' + 'labels and predictions arguments has been switched.') def streaming_auc(predictions, labels, weights=None, @@ -1636,9 +1506,10 @@ def streaming_sensitivity_at_specificity(predictions, updates_collections=updates_collections, name=name) + @deprecated( - None, "Please switch to tf.metrics.precision_at_thresholds. Note that the " - "order of of the inputs of labels and predictions have been switched.") + None, 'Please switch to tf.metrics.precision_at_thresholds. Note that the ' + 'order of the labels and predictions arguments has been switched.') def streaming_precision_at_thresholds(predictions, labels, thresholds, @@ -1697,9 +1568,10 @@ def streaming_precision_at_thresholds(predictions, updates_collections=updates_collections, name=name) + @deprecated( - None, "Please switch to tf.metrics.recall_at_thresholds. Note that the " - "order of of the inputs of labels and predictions have been switched.") + None, 'Please switch to tf.metrics.recall_at_thresholds. Note that the ' + 'order of the labels and predictions arguments has been switched.') def streaming_recall_at_thresholds(predictions, labels, thresholds, @@ -1909,8 +1781,8 @@ def _at_k_name(name, k=None, class_id=None): return name -@deprecated("2016-11-08", "Please use `streaming_sparse_recall_at_k`, " - "and reshape labels from [batch_size] to [batch_size, 1].") +@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, ' + 'and reshape labels from [batch_size] to [batch_size, 1].') def streaming_recall_at_k(predictions, labels, k, @@ -2543,7 +2415,8 @@ def streaming_sparse_average_precision_at_top_k(top_k_predictions, updates_collections=updates_collections, name=name) -@deprecated(None, "Please switch to tf.metrics.mean.") + +@deprecated(None, 'Please switch to tf.metrics.mean.') def streaming_mean_absolute_error(predictions, labels, weights=None, diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index d1beafcb28d..3892654f257 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -25,13 +25,12 @@ py_test( srcs_version = "PY2AND3", deps = [ ":summary_ops", + ":summary_test_internal", ":summary_test_util", "//tensorflow/python:array_ops", "//tensorflow/python:errors", "//tensorflow/python:framework", "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:ops", "//tensorflow/python:platform", "//tensorflow/python:state_ops", "//tensorflow/python:training", @@ -41,6 +40,20 @@ py_test( ], ) +py_test( + name = "summary_ops_graph_test", + srcs = ["summary_ops_graph_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":summary_ops", + ":summary_test_internal", + "//tensorflow/python:client_testlib", + "//tensorflow/python:ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + ], +) + py_library( name = "summary_ops", srcs = ["summary_ops.py"], @@ -98,3 +111,15 @@ py_library( "//tensorflow/python:platform", ], ) + +py_library( + name = "summary_test_internal", + testonly = 1, + srcs = ["summary_test_internal.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/python:lib", + "//tensorflow/python:platform", + ], +) diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index 813e8b2b09d..f783179f614 100644 --- a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""TensorFlow Summary API v2. -"""Contrib summary package. - -The operations in this package are safe to use with eager execution turned or on -off. - +The operations in this package are safe to use with eager execution turned on or +off. It has a more flexible API that allows summaries to be written directly +from ops to places other than event log files, rather than propagating protos +from @{tf.summary.merge_all} to @{tf.summary.FileWriter}. """ from __future__ import absolute_import @@ -32,11 +32,14 @@ from tensorflow.contrib.summary.summary_ops import create_summary_db_writer from tensorflow.contrib.summary.summary_ops import create_summary_file_writer from tensorflow.contrib.summary.summary_ops import eval_dir from tensorflow.contrib.summary.summary_ops import generic +from tensorflow.contrib.summary.summary_ops import graph from tensorflow.contrib.summary.summary_ops import histogram from tensorflow.contrib.summary.summary_ops import image from tensorflow.contrib.summary.summary_ops import import_event +from tensorflow.contrib.summary.summary_ops import initialize from tensorflow.contrib.summary.summary_ops import never_record_summaries from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps from tensorflow.contrib.summary.summary_ops import scalar from tensorflow.contrib.summary.summary_ops import should_record_summaries from tensorflow.contrib.summary.summary_ops import summary_writer_initializer_op +from tensorflow.contrib.summary.summary_ops import SummaryWriter diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index f6be99f6ae8..a72c0c80aab 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -27,6 +27,7 @@ import time import six from tensorflow.contrib.summary import gen_summary_ops +from tensorflow.core.framework import graph_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -99,25 +100,32 @@ def never_record_summaries(): class SummaryWriter(object): - """Encapsulates a summary writer.""" + """Encapsulates a stateful summary writer resource. - def __init__(self, resource): + See also: + - @{tf.contrib.summary.create_summary_file_writer} + - @{tf.contrib.summary.create_summary_db_writer} + """ + + def __init__(self, resource): self._resource = resource if context.in_eager_mode(): self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="cpu:0") def set_as_default(self): + """Enables this summary writer for the current thread.""" context.context().summary_writer_resource = self._resource @tf_contextlib.contextmanager def as_default(self): + """Enables summary writing within a `with` block.""" if self._resource is None: - yield + yield self else: old = context.context().summary_writer_resource context.context().summary_writer_resource = self._resource - yield + yield self # Flushes the summary writer in eager mode or in graph functions, but not # in legacy graph mode (you're on your own there). with ops.device("cpu:0"): @@ -125,6 +133,43 @@ class SummaryWriter(object): context.context().summary_writer_resource = old +def initialize( + graph=None, # pylint: disable=redefined-outer-name + session=None): + """Initializes summary writing for graph execution mode. + + This helper method provides a higher-level alternative to using + @{tf.contrib.summary.summary_writer_initializer_op} and + @{tf.contrib.summary.graph}. + + Most users will also want to call @{tf.train.create_global_step} + which can happen before or after this function is called. + + Args: + graph: A @{tf.Graph} or @{tf.GraphDef} to output to the writer. + This function will not write the default graph by default. When + writing to an event log file, the associated step will be zero. + session: So this method can call @{tf.Session.run}. This defaults + to @{tf.get_default_session}. + + Raises: + RuntimeError: If in eager mode, or if the current thread has no + default @{tf.contrib.summary.SummaryWriter}. + ValueError: If session wasn't passed and no default session. + """ + if context.context().summary_writer_resource is None: + raise RuntimeError("No default tf.contrib.summary.SummaryWriter found") + if session is None: + session = ops.get_default_session() + if session is None: + raise ValueError("session must be passed if no default session exists") + session.run(summary_writer_initializer_op()) + if graph is not None: + data = _serialize_graph(graph) + x = array_ops.placeholder(dtypes.string) + session.run(_graph(x, 0), feed_dict={x: data}) + + def create_summary_file_writer(logdir, max_queue=None, flush_millis=None, @@ -192,10 +237,10 @@ def create_summary_db_writer(db_uri, Experiment will not be associated with a User. Must be valid as both a DNS label and Linux username. name: Shared name for this SummaryWriter resource stored to default - Graph. + @{tf.Graph}. Returns: - A new SummaryWriter instance. + A @{tf.contrib.summary.SummaryWriter} instance. """ with ops.device("cpu:0"): if experiment_name is None: @@ -240,7 +285,16 @@ def _nothing(): def all_summary_ops(): - """Graph-mode only. Returns all summary ops.""" + """Graph-mode only. Returns all summary ops. + + Please note this excludes @{tf.contrib.summary.graph} ops. + + Returns: + The summary ops. + + Raises: + RuntimeError: If in Eager mode. + """ if context.in_eager_mode(): raise RuntimeError( "tf.contrib.summary.all_summary_ops is only supported in graph mode.") @@ -248,7 +302,14 @@ def all_summary_ops(): def summary_writer_initializer_op(): - """Graph-mode only. Returns the list of ops to create all summary writers.""" + """Graph-mode only. Returns the list of ops to create all summary writers. + + Returns: + The initializer ops. + + Raises: + RuntimeError: If in Eager mode. + """ if context.in_eager_mode(): raise RuntimeError( "tf.contrib.summary.summary_writer_initializer_op is only " @@ -367,21 +428,72 @@ def audio(name, tensor, sample_rate, max_outputs, family=None, return summary_writer_function(name, tensor, function, family=family) -def import_event(tensor, name=None): - """Writes a tf.Event binary proto. +def graph(param, step=None, name=None): + """Writes a TensorFlow graph to the summary interface. - When using create_summary_db_writer(), this can be used alongside - tf.TFRecordReader to load event logs into the database. Please note - that this is lower level than the other summary functions and will - ignore any conditions set by methods like should_record_summaries(). + The graph summary is, strictly speaking, not a summary. Conditions + like @{tf.contrib.summary.never_record_summaries} do not apply. Only + a single graph can be associated with a particular run. If multiple + graphs are written, then only the last one will be considered by + TensorBoard. + + When not using eager execution mode, the user should consider passing + the `graph` parameter to @{tf.contrib.summary.initialize} instead of + calling this function. Otherwise special care needs to be taken when + using the graph to record the graph. Args: - tensor: A `Tensor` of type `string` containing a serialized `Event` - proto. + param: A @{tf.Tensor} containing a serialized graph proto. When + eager execution is enabled, this function will automatically + coerce @{tf.Graph}, @{tf.GraphDef}, and string types. + step: The global step variable. This doesn't have useful semantics + for graph summaries, but is used anyway, due to the structure of + event log files. This defaults to the global step. name: A name for the operation (optional). Returns: - The created Operation. + The created @{tf.Operation} or a @{tf.no_op} if summary writing has + not been enabled for this context. + + Raises: + TypeError: If `param` isn't already a @{tf.Tensor} in graph mode. + """ + if not context.in_eager_mode() and not isinstance(param, ops.Tensor): + raise TypeError("graph() needs a tf.Tensor (e.g. tf.placeholder) in graph " + "mode, but was: %s" % type(param)) + writer = context.context().summary_writer_resource + if writer is None: + return control_flow_ops.no_op() + with ops.device("cpu:0"): + if step is None: + step = training_util.get_global_step() + else: + step = ops.convert_to_tensor(step, dtypes.int64) + if isinstance(param, (ops.Graph, graph_pb2.GraphDef)): + tensor = ops.convert_to_tensor(_serialize_graph(param), dtypes.string) + else: + tensor = array_ops.identity(param) + return gen_summary_ops.write_graph_summary(writer, step, tensor, name=name) + +_graph = graph # for functions with a graph parameter + + +def import_event(tensor, name=None): + """Writes a @{tf.Event} binary proto. + + When using create_summary_db_writer(), this can be used alongside + @{tf.TFRecordReader} to load event logs into the database. Please + note that this is lower level than the other summary functions and + will ignore any conditions set by methods like + @{tf.contrib.summary.should_record_summaries}. + + Args: + tensor: A @{tf.Tensor} of type `string` containing a serialized + @{tf.Event} proto. + name: A name for the operation (optional). + + Returns: + The created @{tf.Operation}. """ return gen_summary_ops.import_event( context.context().summary_writer_resource, tensor, name=name) @@ -390,3 +502,10 @@ def import_event(tensor, name=None): def eval_dir(model_dir, name=None): """Construct a logdir for an eval summary writer.""" return os.path.join(model_dir, "eval" if not name else "eval_" + name) + + +def _serialize_graph(arbitrary_graph): + if isinstance(arbitrary_graph, ops.Graph): + return arbitrary_graph.as_graph_def(add_shapes=True).SerializeToString() + else: + return arbitrary_graph.SerializeToString() diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py new file mode 100644 index 00000000000..8f85f67a258 --- /dev/null +++ b/tensorflow/contrib/summary/summary_ops_graph_test.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.summary import summary_ops +from tensorflow.contrib.summary import summary_test_internal +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import training_util + +get_all = summary_test_internal.get_all + + +class DbTest(summary_test_internal.SummaryDbTest): + + def testGraphPassedToGraph_isForbiddenForThineOwnSafety(self): + with self.assertRaises(TypeError): + summary_ops.graph(ops.Graph()) + with self.assertRaises(TypeError): + summary_ops.graph('') + + def testGraphSummary(self): + training_util.get_or_create_global_step() + name = 'hi' + graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),)) + with self.test_session(): + with self.create_summary_db_writer().as_default(): + summary_ops.initialize(graph=graph) + six.assertCountEqual(self, [name], + get_all(self.db, 'SELECT node_name FROM Nodes')) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 6e1a746815f..c5ca054f77f 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -12,20 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools -import os import tempfile import six -import sqlite3 from tensorflow.contrib.summary import summary_ops +from tensorflow.contrib.summary import summary_test_internal from tensorflow.contrib.summary import summary_test_util +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import dtypes @@ -36,6 +35,9 @@ from tensorflow.python.ops import state_ops from tensorflow.python.platform import gfile from tensorflow.python.training import training_util +get_all = summary_test_internal.get_all +get_one = summary_test_internal.get_one + class TargetTest(test_util.TensorFlowTestCase): @@ -77,7 +79,7 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.scalar('scalar', 2.0) write() - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 2.0) @@ -90,7 +92,7 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.scalar('scalar', 2.0) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') @@ -103,27 +105,12 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.scalar('scalar', 2.0, global_step=global_step) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') -class DbTest(test_util.TensorFlowTestCase): - - def setUp(self): - self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') - if os.path.exists(self.db_path): - os.unlink(self.db_path) - self.db = sqlite3.connect(self.db_path) - self.create_summary_db_writer = functools.partial( - summary_ops.create_summary_db_writer, - db_uri=self.db_path, - experiment_name='experiment', - run_name='run', - user_name='user') - - def tearDown(self): - self.db.close() +class DbTest(summary_test_internal.SummaryDbTest): def testIntegerSummaries(self): step = training_util.create_global_step() @@ -186,13 +173,15 @@ class DbTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): self.create_summary_db_writer(user_name='@') - -def get_one(db, q, *p): - return db.execute(q, p).fetchone()[0] - - -def get_all(db, q, *p): - return unroll(db.execute(q, p).fetchall()) + def testGraphSummary(self): + training_util.get_or_create_global_step() + name = 'hi' + graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),)) + with summary_ops.always_record_summaries(): + with self.create_summary_db_writer().as_default(): + summary_ops.graph(graph) + six.assertCountEqual(self, [name], + get_all(self.db, 'SELECT node_name FROM Nodes')) def get_tensor(db, tag_id, step): @@ -205,9 +194,5 @@ def int64(x): return array_ops.constant(x, dtypes.int64) -def unroll(list_of_tuples): - return sum(list_of_tuples, ()) - - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/summary/summary_test_internal.py b/tensorflow/contrib/summary/summary_test_internal.py new file mode 100644 index 00000000000..54233f2f50b --- /dev/null +++ b/tensorflow/contrib/summary/summary_test_internal.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Internal helpers for tests in this directory.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os +import sqlite3 + +from tensorflow.contrib.summary import summary_ops +from tensorflow.python.framework import test_util + + +class SummaryDbTest(test_util.TensorFlowTestCase): + """Helper for summary database testing.""" + + def setUp(self): + super(SummaryDbTest, self).setUp() + self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') + if os.path.exists(self.db_path): + os.unlink(self.db_path) + self.db = sqlite3.connect(self.db_path) + self.create_summary_db_writer = functools.partial( + summary_ops.create_summary_db_writer, + db_uri=self.db_path, + experiment_name='experiment', + run_name='run', + user_name='user') + + def tearDown(self): + self.db.close() + super(SummaryDbTest, self).tearDown() + + +def get_one(db, q, *p): + return db.execute(q, p).fetchone()[0] + + +def get_all(db, q, *p): + return unroll(db.execute(q, p).fetchall()) + + +def unroll(list_of_tuples): + return sum(list_of_tuples, ()) diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py index 37b546d3ab3..794c5b8bab1 100644 --- a/tensorflow/contrib/summary/summary_test_util.py +++ b/tensorflow/contrib/summary/summary_test_util.py @@ -26,16 +26,37 @@ from tensorflow.python.lib.io import tf_record from tensorflow.python.platform import gfile -def events_from_file(logdir): - """Returns all events in the single eventfile in logdir.""" - assert gfile.Exists(logdir) - files = gfile.ListDirectory(logdir) - assert len(files) == 1, "Found more than one file in logdir: %s" % files - records = list( - tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) +def events_from_file(filepath): + """Returns all events in a single event file. + + Args: + filepath: Path to the event file. + + Returns: + A list of all tf.Event protos in the event file. + """ + records = list(tf_record.tf_record_iterator(filepath)) result = [] for r in records: event = event_pb2.Event() event.ParseFromString(r) result.append(event) return result + + +def events_from_logdir(logdir): + """Returns all events in the single eventfile in logdir. + + Args: + logdir: The directory in which the single event file is sought. + + Returns: + A list of all tf.Event protos from the single event file. + + Raises: + AssertionError: If logdir does not contain exactly one file. + """ + assert gfile.Exists(logdir) + files = gfile.ListDirectory(logdir) + assert len(files) == 1, "Found not exactly one file in logdir: %s" % files + return events_from_file(os.path.join(logdir, files[0])) diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc index 98fff9e0ae4..d63b2c6cc23 100644 --- a/tensorflow/contrib/tensorboard/db/schema.cc +++ b/tensorflow/contrib/tensorboard/db/schema.cc @@ -135,8 +135,7 @@ class SqliteSchema { /// the database. This field will be mutated if the run is /// restarted. /// description: Optional markdown information. - /// graph: Snappy tf.GraphDef proto with node field cleared. That - /// field can be recreated using GraphNodes and NodeDefs. + /// graph_id: ID of associated Graphs row. Status CreateRunsTable() { return Run(R"sql( CREATE TABLE IF NOT EXISTS Runs ( @@ -147,7 +146,7 @@ class SqliteSchema { inserted_time REAL, started_time REAL, description TEXT, - graph BLOB + graph_id INTEGER ) )sql"); } @@ -205,46 +204,78 @@ class SqliteSchema { )sql"); } - /// \brief Creates NodeDefs table. - /// - /// This table stores NodeDef protos which define the GraphDef for a - /// Run. This functions like a hash table so rows can be shared by - /// multiple Runs in an Experiment. + /// \brief Creates Graphs table. /// /// Fields: /// rowid: Ephemeral b-tree ID dictating locality. - /// experiment_id: Optional int64 for grouping rows. - /// node_def_id: Permanent >0 unique ID. - /// fingerprint: Optional farmhash::Fingerprint64() of uncompressed - /// node_def bytes, coerced to int64. - /// node_def: BLOB containing a Snappy tf.NodeDef proto. - Status CreateNodeDefsTable() { + /// graph_id: Permanent >0 unique ID. + /// inserted_time: Float UNIX timestamp with µs precision. This is + /// always the wall time of when the row was inserted into the + /// DB. It may be used as a hint for an archival job. + /// node_def: Contains Snappy tf.GraphDef proto. All fields will be + /// cleared except those not expressed in SQL. + Status CreateGraphsTable() { return Run(R"sql( - CREATE TABLE IF NOT EXISTS NodeDefs ( + CREATE TABLE IF NOT EXISTS Graphs ( rowid INTEGER PRIMARY KEY, - experiment_id INTEGER, - node_def_id INTEGER NOT NULL, - fingerprint INTEGER, - node_def TEXT + graph_id INTEGER NOT NULL, + inserted_time REAL, + graph_def BLOB ) )sql"); } - /// \brief Creates RunNodeDefs table. - /// - /// Table mapping Runs to NodeDefs. This is used to recreate the node - /// field of the GraphDef proto. + /// \brief Creates Nodes table. /// /// Fields: /// rowid: Ephemeral b-tree ID dictating locality. - /// run_id: Mandatory ID of associated Run. - /// node_def_id: Mandatory ID of associated NodeDef. - Status CreateRunNodeDefsTable() { + /// graph_id: Permanent >0 unique ID. + /// node_id: ID for this node. This is more like a 0-index within + /// the Graph. Please note indexes are allowed to be removed. + /// node_name: Unique name for this Node within Graph. This is + /// copied from the proto so it can be indexed. This is allowed + /// to be NULL to save space on the index, in which case the + /// node_def.name proto field must not be cleared. + /// op: Copied from tf.NodeDef proto. + /// device: Copied from tf.NodeDef proto. + /// node_def: Contains Snappy tf.NodeDef proto. All fields will be + /// cleared except those not expressed in SQL. + Status CreateNodesTable() { return Run(R"sql( - CREATE TABLE IF NOT EXISTS RunNodeDefs ( + CREATE TABLE IF NOT EXISTS Nodes ( rowid INTEGER PRIMARY KEY, - run_id INTEGER NOT NULL, - node_def_id INTEGER NOT NULL + graph_id INTEGER NOT NULL, + node_id INTEGER NOT NULL, + node_name TEXT, + op TEXT, + device TEXT, + node_def BLOB + ) + )sql"); + } + + /// \brief Creates NodeInputs table. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// graph_id: Permanent >0 unique ID. + /// node_id: Index of Node in question. This can be considered the + /// 'to' vertex. + /// idx: Used for ordering inputs on a given Node. + /// input_node_id: Nodes.node_id of the corresponding input node. + /// This can be considered the 'from' vertex. + /// is_control: If non-zero, indicates this input is a controlled + /// dependency, which means this isn't an edge through which + /// tensors flow. NULL means 0. + Status CreateNodeInputsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS NodeInputs ( + rowid INTEGER PRIMARY KEY, + graph_id INTEGER NOT NULL, + node_id INTEGER NOT NULL, + idx INTEGER NOT NULL, + input_node_id INTEGER NOT NULL, + is_control INTEGER ) )sql"); } @@ -297,11 +328,27 @@ class SqliteSchema { )sql"); } - /// \brief Uniquely indexes node_def_id on NodeDefs table. - Status CreateNodeDefIdIndex() { + /// \brief Uniquely indexes graph_id on Graphs table. + Status CreateGraphIdIndex() { return Run(R"sql( - CREATE UNIQUE INDEX IF NOT EXISTS NodeDefIdIndex - ON NodeDefs (node_def_id) + CREATE UNIQUE INDEX IF NOT EXISTS GraphIdIndex + ON Graphs (graph_id) + )sql"); + } + + /// \brief Uniquely indexes (graph_id, node_id) on Nodes table. + Status CreateNodeIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS NodeIdIndex + ON Nodes (graph_id, node_id) + )sql"); + } + + /// \brief Uniquely indexes (graph_id, node_id, idx) on NodeInputs table. + Status CreateNodeInputsIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS NodeInputsIndex + ON NodeInputs (graph_id, node_id, idx) )sql"); } @@ -350,20 +397,12 @@ class SqliteSchema { )sql"); } - /// \brief Indexes (experiment_id, fingerprint) on NodeDefs table. - Status CreateNodeDefFingerprintIndex() { + /// \brief Uniquely indexes (graph_id, node_name) on Nodes table. + Status CreateNodeNameIndex() { return Run(R"sql( - CREATE INDEX IF NOT EXISTS NodeDefFingerprintIndex - ON NodeDefs (experiment_id, fingerprint) - WHERE fingerprint IS NOT NULL - )sql"); - } - - /// \brief Uniquely indexes (run_id, node_def_id) on RunNodeDefs table. - Status CreateRunNodeDefIndex() { - return Run(R"sql( - CREATE UNIQUE INDEX IF NOT EXISTS RunNodeDefIndex - ON RunNodeDefs (run_id, node_def_id) + CREATE UNIQUE INDEX IF NOT EXISTS NodeNameIndex + ON Nodes (graph_id, node_name) + WHERE node_name IS NOT NULL )sql"); } @@ -387,22 +426,24 @@ Status SetupTensorboardSqliteDb(std::shared_ptr db) { TF_RETURN_IF_ERROR(s.CreateRunsTable()); TF_RETURN_IF_ERROR(s.CreateExperimentsTable()); TF_RETURN_IF_ERROR(s.CreateUsersTable()); - TF_RETURN_IF_ERROR(s.CreateNodeDefsTable()); - TF_RETURN_IF_ERROR(s.CreateRunNodeDefsTable()); + TF_RETURN_IF_ERROR(s.CreateGraphsTable()); + TF_RETURN_IF_ERROR(s.CreateNodeInputsTable()); + TF_RETURN_IF_ERROR(s.CreateNodesTable()); TF_RETURN_IF_ERROR(s.CreateTensorIndex()); TF_RETURN_IF_ERROR(s.CreateTensorChunkIndex()); TF_RETURN_IF_ERROR(s.CreateTagIdIndex()); TF_RETURN_IF_ERROR(s.CreateRunIdIndex()); TF_RETURN_IF_ERROR(s.CreateExperimentIdIndex()); TF_RETURN_IF_ERROR(s.CreateUserIdIndex()); - TF_RETURN_IF_ERROR(s.CreateNodeDefIdIndex()); + TF_RETURN_IF_ERROR(s.CreateGraphIdIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeIdIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeInputsIndex()); TF_RETURN_IF_ERROR(s.CreateTagNameIndex()); TF_RETURN_IF_ERROR(s.CreateRunNameIndex()); TF_RETURN_IF_ERROR(s.CreateExperimentNameIndex()); TF_RETURN_IF_ERROR(s.CreateUserNameIndex()); TF_RETURN_IF_ERROR(s.CreateUserEmailIndex()); - TF_RETURN_IF_ERROR(s.CreateNodeDefFingerprintIndex()); - TF_RETURN_IF_ERROR(s.CreateRunNodeDefIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeNameIndex()); return Status::OK(); } diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index a26ad616603..ae063d24efe 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -15,17 +15,29 @@ limitations under the License. #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" #include "tensorflow/contrib/tensorboard/db/schema.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/snappy.h" #include "tensorflow/core/util/event.pb.h" namespace tensorflow { namespace { +double GetWallTime(Env* env) { + // TODO(@jart): Follow precise definitions for time laid out in schema. + // TODO(@jart): Use monotonic clock from gRPC codebase. + return static_cast(env->NowMicros()) / 1.0e6; +} + int64 MakeRandomId() { + // TODO(@jart): Try generating ID in 2^24 space, falling back to 2^63 + // https://sqlite.org/src4/doc/trunk/www/varint.wiki int64 id = static_cast(random::New64() & ((1ULL << 63) - 1)); if (id == 0) { ++id; @@ -33,10 +45,201 @@ int64 MakeRandomId() { return id; } +Status Serialize(const protobuf::MessageLite& proto, string* output) { + output->clear(); + if (!proto.SerializeToString(output)) { + return errors::DataLoss("SerializeToString failed"); + } + return Status::OK(); +} + +Status Compress(const string& data, string* output) { + output->clear(); + if (!port::Snappy_Compress(data.data(), data.size(), output)) { + return errors::FailedPrecondition("TensorBase needs Snappy"); + } + return Status::OK(); +} + +Status BindProto(SqliteStatement* stmt, int parameter, + const protobuf::MessageLite& proto) { + string serialized; + TF_RETURN_IF_ERROR(Serialize(proto, &serialized)); + string compressed; + TF_RETURN_IF_ERROR(Compress(serialized, &compressed)); + stmt->BindBlobUnsafe(parameter, compressed); + return Status::OK(); +} + +Status BindTensor(SqliteStatement* stmt, int parameter, const Tensor& t) { + // TODO(@jart): Make portable between little and big endian systems. + // TODO(@jart): Use TensorChunks with minimal copying for big tensors. + // TODO(@jart): Add field to indicate encoding. + // TODO(@jart): Allow crunch tool to re-compress with zlib instead. + TensorProto p; + t.AsProtoTensorContent(&p); + return BindProto(stmt, parameter, p); +} + +class Transactor { + public: + explicit Transactor(std::shared_ptr db) + : db_(std::move(db)), + begin_(db_->Prepare("BEGIN TRANSACTION")), + commit_(db_->Prepare("COMMIT TRANSACTION")), + rollback_(db_->Prepare("ROLLBACK TRANSACTION")) {} + + template + Status Transact(T callback, Args&&... args) { + TF_RETURN_IF_ERROR(begin_.StepAndReset()); + Status s = callback(std::forward(args)...); + if (s.ok()) { + TF_RETURN_IF_ERROR(commit_.StepAndReset()); + } else { + TF_RETURN_WITH_CONTEXT_IF_ERROR(rollback_.StepAndReset(), s.ToString()); + } + return s; + } + + private: + std::shared_ptr db_; + SqliteStatement begin_; + SqliteStatement commit_; + SqliteStatement rollback_; +}; + +class GraphSaver { + public: + static Status SaveToRun(Env* env, Sqlite* db, GraphDef* graph, int64 run_id) { + auto get = db->Prepare("SELECT graph_id FROM Runs WHERE run_id = ?"); + get.BindInt(1, run_id); + bool is_done; + TF_RETURN_IF_ERROR(get.Step(&is_done)); + int64 graph_id = is_done ? 0 : get.ColumnInt(0); + if (graph_id == 0) { + graph_id = MakeRandomId(); + // TODO(@jart): Check for ID collision. + auto set = db->Prepare("UPDATE Runs SET graph_id = ? WHERE run_id = ?"); + set.BindInt(1, graph_id); + set.BindInt(2, run_id); + TF_RETURN_IF_ERROR(set.StepAndReset()); + } + return Save(env, db, graph, graph_id); + } + + static Status Save(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id) { + GraphSaver saver{env, db, graph, graph_id}; + saver.MapNameToNodeId(); + TF_RETURN_IF_ERROR(saver.SaveNodeInputs()); + TF_RETURN_IF_ERROR(saver.SaveNodes()); + TF_RETURN_IF_ERROR(saver.SaveGraph()); + return Status::OK(); + } + + private: + GraphSaver(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id) + : env_(env), db_(db), graph_(graph), graph_id_(graph_id) {} + + void MapNameToNodeId() { + size_t toto = static_cast(graph_->node_size()); + name_copies_.reserve(toto); + name_to_node_id_.reserve(toto); + for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { + // Copy name into memory region, since we call clear_name() later. + // Then wrap in StringPiece so we can compare slices without copy. + name_copies_.emplace_back(graph_->node(node_id).name()); + name_to_node_id_.emplace(name_copies_.back(), node_id); + } + } + + Status SaveNodeInputs() { + auto purge = db_->Prepare("DELETE FROM NodeInputs WHERE graph_id = ?"); + purge.BindInt(1, graph_id_); + TF_RETURN_IF_ERROR(purge.StepAndReset()); + auto insert = db_->Prepare(R"sql( + INSERT INTO NodeInputs (graph_id, node_id, idx, input_node_id, is_control) + VALUES (?, ?, ?, ?, ?) + )sql"); + for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { + const NodeDef& node = graph_->node(node_id); + for (int idx = 0; idx < node.input_size(); ++idx) { + StringPiece name = node.input(idx); + insert.BindInt(1, graph_id_); + insert.BindInt(2, node_id); + insert.BindInt(3, idx); + if (!name.empty() && name[0] == '^') { + name.remove_prefix(1); + insert.BindInt(5, 1); + } + auto e = name_to_node_id_.find(name); + if (e == name_to_node_id_.end()) { + return errors::DataLoss("Could not find node: ", name); + } + insert.BindInt(4, e->second); + TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(), + " -> ", name); + } + } + return Status::OK(); + } + + Status SaveNodes() { + auto purge = db_->Prepare("DELETE FROM Nodes WHERE graph_id = ?"); + purge.BindInt(1, graph_id_); + TF_RETURN_IF_ERROR(purge.StepAndReset()); + auto insert = db_->Prepare(R"sql( + INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def) + VALUES (?, ?, ?, ?, ?, ?) + )sql"); + for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { + NodeDef* node = graph_->mutable_node(node_id); + insert.BindInt(1, graph_id_); + insert.BindInt(2, node_id); + insert.BindText(3, node->name()); + node->clear_name(); + if (!node->op().empty()) { + insert.BindText(4, node->op()); + node->clear_op(); + } + if (!node->device().empty()) { + insert.BindText(5, node->device()); + node->clear_device(); + } + node->clear_input(); + TF_RETURN_IF_ERROR(BindProto(&insert, 6, *node)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name()); + } + return Status::OK(); + } + + Status SaveGraph() { + auto insert = db_->Prepare(R"sql( + INSERT OR REPLACE INTO Graphs (graph_id, inserted_time, graph_def) + VALUES (?, ?, ?) + )sql"); + insert.BindInt(1, graph_id_); + insert.BindDouble(2, GetWallTime(env_)); + graph_->clear_node(); + TF_RETURN_IF_ERROR(BindProto(&insert, 3, *graph_)); + return insert.StepAndReset(); + } + + Env* env_; + Sqlite* db_; + GraphDef* graph_; + int64 graph_id_; + std::vector name_copies_; + std::unordered_map name_to_node_id_; +}; + class SummaryDbWriter : public SummaryWriterInterface { public: SummaryDbWriter(Env* env, std::shared_ptr db) - : SummaryWriterInterface(), env_(env), db_(std::move(db)), run_id_(-1) {} + : SummaryWriterInterface(), + env_(env), + db_(std::move(db)), + txn_(db_), + run_id_{0LL} {} ~SummaryDbWriter() override {} Status Initialize(const string& experiment_name, const string& run_name, @@ -76,7 +279,7 @@ class SummaryDbWriter : public SummaryWriterInterface { // TODO(@jart): Check for random ID collisions without needing txn retry. insert_tensor_.BindInt(1, tag_id); insert_tensor_.BindInt(2, global_step); - insert_tensor_.BindDouble(3, GetWallTime()); + insert_tensor_.BindDouble(3, GetWallTime(env_)); switch (t.dtype()) { case DT_INT64: insert_tensor_.BindInt(4, t.scalar()()); @@ -85,22 +288,41 @@ class SummaryDbWriter : public SummaryWriterInterface { insert_tensor_.BindDouble(4, t.scalar()()); break; default: - TF_RETURN_IF_ERROR(BindTensor(t)); + TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t)); break; } return insert_tensor_.StepAndReset(); } - Status WriteEvent(std::unique_ptr e) override { + Status WriteGraph(int64 global_step, std::unique_ptr g) override { mutex_lock ml(mu_); TF_RETURN_IF_ERROR(InitializeParents()); - if (e->what_case() == Event::WhatCase::kSummary) { - const Summary& summary = e->summary(); - for (int i = 0; i < summary.value_size(); ++i) { - TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i))); + return txn_.Transact(GraphSaver::SaveToRun, env_, db_.get(), g.get(), + run_id_); + } + + Status WriteEvent(std::unique_ptr e) override { + switch (e->what_case()) { + case Event::WhatCase::kSummary: { + mutex_lock ml(mu_); + TF_RETURN_IF_ERROR(InitializeParents()); + const Summary& summary = e->summary(); + for (int i = 0; i < summary.value_size(); ++i) { + TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i))); + } + return Status::OK(); } + case Event::WhatCase::kGraphDef: { + std::unique_ptr graph{new GraphDef}; + if (!ParseProtoUnlimited(graph.get(), e->graph_def())) { + return errors::DataLoss("parse event.graph_def failed"); + } + return WriteGraph(e->step(), std::move(graph)); + } + default: + // TODO(@jart): Handle other stuff. + return Status::OK(); } - return Status::OK(); } Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { @@ -136,33 +358,8 @@ class SummaryDbWriter : public SummaryWriterInterface { string DebugString() override { return "SummaryDbWriter"; } private: - double GetWallTime() { - // TODO(@jart): Follow precise definitions for time laid out in schema. - // TODO(@jart): Use monotonic clock from gRPC codebase. - return static_cast(env_->NowMicros()) / 1.0e6; - } - - Status BindTensor(const Tensor& t) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // TODO(@jart): Make portable between little and big endian systems. - // TODO(@jart): Use TensorChunks with minimal copying for big tensors. - TensorProto p; - t.AsProtoTensorContent(&p); - string encoded; - if (!p.SerializeToString(&encoded)) { - return errors::DataLoss("SerializeToString failed"); - } - // TODO(@jart): Put byte at beginning of blob to indicate encoding. - // TODO(@jart): Allow crunch tool to re-compress with zlib instead. - string compressed; - if (!port::Snappy_Compress(encoded.data(), encoded.size(), &compressed)) { - return errors::FailedPrecondition("TensorBase needs Snappy"); - } - insert_tensor_.BindBlobUnsafe(4, compressed); - return Status::OK(); - } - Status InitializeParents() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (run_id_ >= 0) { + if (run_id_ > 0) { return Status::OK(); } int64 user_id; @@ -195,7 +392,7 @@ class SummaryDbWriter : public SummaryWriterInterface { )sql"); insert_user.BindInt(1, *user_id); insert_user.BindText(2, user_name); - insert_user.BindDouble(3, GetWallTime()); + insert_user.BindDouble(3, GetWallTime(env_)); TF_RETURN_IF_ERROR(insert_user.StepAndReset()); } return Status::OK(); @@ -249,7 +446,7 @@ class SummaryDbWriter : public SummaryWriterInterface { } insert.BindInt(2, *id); insert.BindText(3, name); - insert.BindDouble(4, GetWallTime()); + insert.BindDouble(4, GetWallTime(env_)); TF_RETURN_IF_ERROR(insert.StepAndReset()); } return Status::OK(); @@ -276,6 +473,7 @@ class SummaryDbWriter : public SummaryWriterInterface { mutex mu_; Env* env_; std::shared_ptr db_ GUARDED_BY(mu_); + Transactor txn_ GUARDED_BY(mu_); SqliteStatement insert_tensor_ GUARDED_BY(mu_); SqliteStatement update_metadata_ GUARDED_BY(mu_); string user_name_ GUARDED_BY(mu_); diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index c1af51e7b7a..3431842ca21 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/db/sqlite.h" @@ -212,5 +214,81 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { kTolerance); } +TEST_F(SummaryDbWriterTest, WriteGraph) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_)); + env_.AdvanceByMillis(23); + GraphDef graph; + NodeDef* node = graph.add_node(); + node->set_name("x"); + node->set_op("Placeholder"); + node = graph.add_node(); + node->set_name("y"); + node->set_op("Placeholder"); + node = graph.add_node(); + node->set_name("z"); + node->set_op("Love"); + node = graph.add_node(); + node->set_name("+"); + node->set_op("Add"); + node->add_input("x"); + node->add_input("y"); + node->add_input("^z"); + node->set_device("tpu/lol"); + std::unique_ptr e{new Event}; + graph.SerializeToString(e->mutable_graph_def()); + TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); + TF_ASSERT_OK(writer_->Flush()); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Graphs")); + ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes")); + ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs")); + + int64 graph_id = QueryInt("SELECT graph_id FROM Graphs"); + EXPECT_GT(graph_id, 0LL); + EXPECT_EQ(graph_id, QueryInt("SELECT graph_id FROM Runs")); + EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs")); + EXPECT_FALSE(QueryString("SELECT graph_def FROM Graphs").empty()); + + EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0")); + EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1")); + EXPECT_EQ("z", QueryString("SELECT node_name FROM Nodes WHERE node_id = 2")); + EXPECT_EQ("+", QueryString("SELECT node_name FROM Nodes WHERE node_id = 3")); + + EXPECT_EQ("Placeholder", + QueryString("SELECT op FROM Nodes WHERE node_id = 0")); + EXPECT_EQ("Placeholder", + QueryString("SELECT op FROM Nodes WHERE node_id = 1")); + EXPECT_EQ("Love", QueryString("SELECT op FROM Nodes WHERE node_id = 2")); + EXPECT_EQ("Add", QueryString("SELECT op FROM Nodes WHERE node_id = 3")); + + EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 0")); + EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 1")); + EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 2")); + EXPECT_EQ("tpu/lol", + QueryString("SELECT device FROM Nodes WHERE node_id = 3")); + + EXPECT_EQ(graph_id, + QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(graph_id, + QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(graph_id, + QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 2")); + + EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 2")); + + EXPECT_EQ(0LL, + QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(1LL, + QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(2LL, + QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 2")); + + EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2")); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py index 6a5fe06ff07..ec4c4e1be6f 100644 --- a/tensorflow/contrib/tpu/__init__.py +++ b/tensorflow/contrib/tpu/__init__.py @@ -24,7 +24,6 @@ @@initialize_system @@shutdown_system @@core -@@outside_all_rewrites @@replicate @@shard @@batch_parallel diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto index d8ee2437909..2d2207a43fe 100644 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -124,4 +124,6 @@ message TfOpStats { optional LoopingResult looping = 4; // The result for the HloExtraInfoMap. optional HloExtraInfoMapResult hlo_extrainfo_map = 5; + // Overall matrix unit utilization in percentage. + optional double matrix_unit_utilization_percent = 6; } diff --git a/tensorflow/contrib/tpu/python/tpu/test_util.py b/tensorflow/contrib/tpu/python/tpu/test_util.py index b83c72d0ffe..a5d4ff97227 100644 --- a/tensorflow/contrib/tpu/python/tpu/test_util.py +++ b/tensorflow/contrib/tpu/python/tpu/test_util.py @@ -32,6 +32,7 @@ from tensorflow.python.client import session as tf_session from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import variables @@ -89,7 +90,12 @@ def copy_dir(src, tgt): gfile.Copy(src_f, tgt_f, overwrite=True) -def compare_model(model_fn, input_fn, params, master="local", temp_dir=None, +def compare_model(model_fn, + input_fn, + params, + master="local", + temp_dir=None, + num_shards=2, tolerance=1e-4): """Compare the results of running `model_fn` on the TPU and CPU.""" if not temp_dir: @@ -102,7 +108,17 @@ def compare_model(model_fn, input_fn, params, master="local", temp_dir=None, logging.info("Checkpoints and weights will be written to %s", temp_dir) num_steps = 1 - num_shards = 8 + + def _model_adapter(features, labels, mode, params): + """Run users model function with random seeds fixed to known values.""" + random_seed.set_random_seed(0) + np.random.seed(0) + return model_fn(features, labels, mode, params) + + def _input_adapter(params): + random_seed.set_random_seed(0) + np.random.seed(0) + return input_fn(params) def _make_run_config(model_dir): return tpu_config.RunConfig( @@ -119,7 +135,7 @@ def compare_model(model_fn, input_fn, params, master="local", temp_dir=None, def _make_estimator(use_tpu, model_dir): return tpu_estimator.TPUEstimator( - model_fn=model_fn, + model_fn=_model_adapter, use_tpu=use_tpu, config=_make_run_config(model_dir), train_batch_size=num_shards, @@ -131,8 +147,9 @@ def compare_model(model_fn, input_fn, params, master="local", temp_dir=None, weights = {} graph = ops.Graph() with graph.as_default(): + features, labels = _input_adapter(dict(params, batch_size=num_shards)) model_fn( - *input_fn(params), + features, labels, params=dict(params, use_tpu=False), mode=model_fn_lib.ModeKeys.TRAIN) saver = tf_saver.Saver() @@ -148,10 +165,15 @@ def compare_model(model_fn, input_fn, params, master="local", temp_dir=None, return weights def _run_step(use_tpu, model_dir): + """Create an estimator and run a single step on the given device.""" + tf_session.Session.reset(target=master) + + logging.info("Running step. TPU=%d. model_dir=%s", use_tpu, model_dir) est = _make_estimator(use_tpu=use_tpu, model_dir=model_dir) - est.train(input_fn=input_fn, steps=num_steps) + est.train(input_fn=_input_adapter, steps=num_steps) weights = _extract_weights(est.latest_checkpoint()) - with gfile.Open(temp_dir + "tpu-%d.weights" % use_tpu, "wb") as f: + with gfile.Open(os.path.join(temp_dir, "tpu-%d.weights" % use_tpu), + "wb") as f: f.write(pickle.dumps(weights)) return weights @@ -159,9 +181,9 @@ def compare_model(model_fn, input_fn, params, master="local", temp_dir=None, _run_step(use_tpu=False, model_dir=initial_model_dir) copy_dir(initial_model_dir, cpu_model_dir) - cpu_weights = _run_step(use_tpu=False, model_dir=cpu_model_dir) - copy_dir(initial_model_dir, tpu_model_dir) + + cpu_weights = _run_step(use_tpu=False, model_dir=cpu_model_dir) tpu_weights = _run_step(use_tpu=True, model_dir=tpu_model_dir) bad_weights = False diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index d521297d994..f3ddc097544 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -19,7 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.tpu.python.ops import tpu_ops @@ -30,6 +29,11 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging + + +_SUMMARY_OPS = ("ScalarSummary",) +_PLACEHOLDER_OPS = ("Placeholder",) def initialize_system(embedding_config=None, job=None): @@ -81,26 +85,6 @@ def core(num): return "device:TPU_REPLICATED_CORE:{}".format(num) -# Experimental API to 'break out' of a tpu.rewrite() (or shard(), etc.) context. -# In -# -# XXX -# with tpu.rewrite(...): -# YYY -# with tpu.outside_all_rewrites(): -# ZZZ -# -# the Ops in ZZZ are added outside the scope of the rewrite(). -# TODO(phawkins): currently outside_all_rewrites() pops out of all nested -# control flow scopes, for example loops. It would make more sense if it only -# popped out of a single scope. -@contextlib.contextmanager -def outside_all_rewrites(): - """Experimental API to 'break out' of a tpu.rewrite() (or shard(), etc.).""" - with ops.control_dependencies(None): - yield - - class TPUReplicateContext(control_flow_ops.ControlFlowContext): """A ControlFlowContext for nodes inside a TPU computation. @@ -124,6 +108,13 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): def _AddOpInternal(self, op): # pylint: disable=protected-access + if op.type in _PLACEHOLDER_OPS: + raise ValueError("Placeholder %s is not supported." % op.name) + + if op.type in _SUMMARY_OPS: + logging.warning( + "Summary operations are not currently supported (%s)" % op.name) + if any(x.dtype._is_ref_dtype for x in op.inputs): raise NotImplementedError( "Non-resource Variables are not supported inside TPU computations " diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 07877fcc761..97b2d25e0cf 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -946,23 +946,14 @@ class _InputPipeline(object): # user code, so, log a warning. if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): err_msg = ('Input pipeline contains one or more QueueRunners. ' - 'These are not supported via TPUEstimator. You must convert ' - 'your input pipeline to use `tf.data` instead (see ' + 'It could be slow and not scalable. Please consider ' + 'converting your input pipeline to use `tf.data` instead (see ' 'https://www.tensorflow.org/programmers_guide/datasets for ' 'instructions.') if _WRAP_INPUT_FN_INTO_WHILE_LOOP: raise RuntimeError(err_msg) else: logging.warn(err_msg) - elif ops.get_default_graph().get_collection(ops.GraphKeys.SUMMARIES): - # Queue Runner has summary Ops by default. So here we use elif to do - # necessary checks for Dataset input pipeline only. - err_msg = ('Input pipeline contains `tf.summary` operations. ' - 'These are not currently supported.') - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - raise RuntimeError(err_msg) - else: - logging.warn(err_msg) class _ModelFnWrapper(object): diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index d95d958d5af..f222d345abe 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -272,7 +272,10 @@ void RunApiTest(bool update_api_def, const string& api_files_dir) { for (auto new_api_entry : new_api_defs_map) { const auto& file_path = new_api_entry.first; - const auto& golden_api_defs_str = golden_api_defs_map.at(file_path); + std::string golden_api_defs_str = ""; + if (golden_api_defs_map.find(file_path) != golden_api_defs_map.end()) { + golden_api_defs_str = golden_api_defs_map.at(file_path); + } string new_api_defs_str = new_api_entry.second.DebugString(); new_api_defs_str = PBTxtToMultiline(new_api_defs_str, multi_line_fields); if (golden_api_defs_str == new_api_defs_str) { diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index 674111dbe69..3912cd177b6 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -110,12 +110,9 @@ class Device : public DeviceBase { // prototyping of TensorFlow device implementations that need to modify // the GraphDef before execution. // - // 'library' provides access to the function library which is shared - // between all device partitions. // 'graph' supplies the partition of the graph assigned to this // device. - virtual Status MaybeRewriteGraph(const FunctionDefLibrary& /*library*/, - std::unique_ptr* /*graph*/) { + virtual Status MaybeRewriteGraph(std::unique_ptr* /*graph*/) { return Status::OK(); } diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index d652b1004ff..2f57164dcd8 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1419,11 +1419,7 @@ Status DirectSession::CreateGraphs( Device* d; s = device_mgr_->LookupDevice(partition_name, &d); if (!s.ok()) break; - // TODO(pbar) The library is currently shared and immutable. There - // may be possible use cases where a device may want to modify - // function definitions - in which case the library would need to be - // replicated per device. - s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph); + s = d->MaybeRewriteGraph(graph); if (!s.ok()) { break; } diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h index 22a70fbdfae..3103ca07512 100644 --- a/tensorflow/core/common_runtime/renamed_device.h +++ b/tensorflow/core/common_runtime/renamed_device.h @@ -104,9 +104,8 @@ class RenamedDevice : public Device { Status Sync() override { return underlying_->Sync(); } - Status MaybeRewriteGraph(const FunctionDefLibrary& library, - std::unique_ptr* graph) override { - return underlying_->MaybeRewriteGraph(library, graph); + Status MaybeRewriteGraph(std::unique_ptr* graph) override { + return underlying_->MaybeRewriteGraph(graph); } Status FillContextMap(const Graph* graph, diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 8e314c7ea57..10901da192f 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -129,80 +129,82 @@ Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner, // Maybe we won't support recursive functions at all in TF, because of // other maintanabilty issues. Status ShapeRefiner::InferShapesForFunction( - const tensorflow::FunctionLibraryDefinition& function_library, - const tensorflow::FunctionDef& function_def, bool keep_nested_shapes, + const tensorflow::FunctionDef* function_def, bool keep_nested_shapes, ExtendedInferenceContext* outer_context) { - InstantiationResult result; - TF_RETURN_IF_ERROR(InstantiateFunction( - function_def, outer_context->get_context()->attrs(), - [&function_library](const string& op, const OpDef** sig) { - return function_library.LookUpOpDef(op, sig); - }, - &result)); + const Graph* graph; + auto it = functions_.find(function_def); + if (it != functions_.end()) { + graph = it->second.get(); + } else { + InstantiationResult result; + TF_RETURN_IF_ERROR(InstantiateFunction( + *function_def, outer_context->get_context()->attrs(), + [this](const string& op, const OpDef** sig) { + return this->function_library_->LookUpOpDef(op, sig); + }, + &result)); - Graph graph(&function_library); - { + Graph* new_graph = new Graph(function_library_); GraphConstructorOptions options; options.allow_internal_ops = true; - TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(options, result.nodes, &graph)); + TF_RETURN_IF_ERROR( + ConvertNodeDefsToGraph(options, result.nodes, new_graph)); + functions_[function_def].reset(new_graph); + graph = new_graph; } - ShapeRefiner refiner(graph.versions().producer(), &function_library); - refiner.set_disable_constant_propagation(disable_constant_propagation_); - refiner.set_function_library_for_shape_inference(&function_library); - if (keep_nested_shapes) refiner.set_keep_nested_shape_inferences(); - + std::unordered_set function_nodes; + Status inference_status = Status::OK(); { - Status inference_status = Status::OK(); - auto node_shape_inference_lambda = [&refiner, &outer_context, + auto node_shape_inference_lambda = [this, &outer_context, &function_nodes, &inference_status](const Node* node) { if (!inference_status.ok()) return; inference_status = InferShapesForFunctionSubNode( - node, &refiner, outer_context->get_context()); + node, this, outer_context->get_context()); + function_nodes.insert(node); }; // Calls inference lambda for each node after visiting all predecessors. // Ensures that we are adding nodes to ShapeRefiner in the topological // order. - ReverseDFS(graph, {}, node_shape_inference_lambda); - - TF_RETURN_IF_ERROR(inference_status); + ReverseDFS(*graph, {}, node_shape_inference_lambda); } - if (keep_nested_shapes) { + if (keep_nested_shapes && inference_status.ok()) { // Fill the nested inferences map. // // The materialized function graph has extra nodes for arguments and // return values, which are not explicitly listed in the FunctionDef, // we filter out these special nodes here to not expose the implementation // details and keep only inferences for the nodes listed in the FunctionDef. - - auto stolen_contexts = refiner.StealInferenceContexts(); - std::unordered_map user_defined_nodes; - for (const auto& node_def : function_def.node_def()) { + for (const auto& node_def : function_def->node_def()) { user_defined_nodes[node_def.name()] = &node_def; } std::unordered_map> nested_inferences; - for (auto& stolen_kv : stolen_contexts) { - auto& stolen_name = stolen_kv.first->name(); - if (user_defined_nodes.find(stolen_name) != user_defined_nodes.end()) { - nested_inferences[stolen_name] = std::move(stolen_kv.second); - - // By default InferenceContext refers to a NodeDef from Graph, - // we have to change it to a NodeDef with longer lifetime, - // because the Graph is a temporary in this function. - nested_inferences[stolen_name]->get_context()->node_def_ = - user_defined_nodes[stolen_name]; + for (const Node* node : function_nodes) { + const string& node_name = node->name(); + if (user_defined_nodes.find(node_name) != user_defined_nodes.end()) { + nested_inferences[node_name] = std::move(node_to_context_[node]); + node_to_context_.erase(node); + // By default InferenceContext refers to a NodeDef from Graph. + // Change it to the publicly accessible NodeDef of the function + // definition. + nested_inferences[node_name]->get_context()->node_def_ = + user_defined_nodes[node_name]; } } - outer_context->set_nested_inferences(std::move(nested_inferences)); + } else { + // Delete the contexts created for the functions nodes to save memory. + for (const Node* node : function_nodes) { + node_to_context_.erase(node); + } } - return Status::OK(); + return inference_status; } Status ShapeRefiner::AddNode(const Node* node) { @@ -781,9 +783,8 @@ Status ShapeRefiner::RunShapeFn(const Node* node, auto* func_def = function_library_->Find(op_reg_data->op_def.name()); if (func_def) { - TF_RETURN_IF_ERROR(InferShapesForFunction( - *function_library_, *func_def, keep_nested_shape_inferences_, ec)); - return Status::OK(); + return InferShapesForFunction(func_def, keep_nested_shape_inferences_, + ec); } } diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index 570b4db1635..da42c30ce94 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -159,6 +159,7 @@ class ShapeRefiner { // With this enabled, shape inference can take more time since it descends // into all function calls. It doesn't do inference once for each function // definition, but once for each function call. + // The function library must outlive the shape refiner. void set_function_library_for_shape_inference( const tensorflow::FunctionLibraryDefinition* lib) { function_library_ = lib; @@ -210,10 +211,9 @@ class ShapeRefiner { // - outer_context will contain output shapes inferred from input shapes // - outer_context will contain nested inferences collection, iff // keep_nested_shapes is true - Status InferShapesForFunction( - const tensorflow::FunctionLibraryDefinition& function_library, - const tensorflow::FunctionDef& function_def, bool keep_nested_shapes, - ExtendedInferenceContext* outer_context); + Status InferShapesForFunction(const tensorflow::FunctionDef* function_def, + bool keep_nested_shapes, + ExtendedInferenceContext* outer_context); // Tries to infer tensor output based on the input shapes of the node. In some // cases, the shapes of the inputs are sufficient for inferring the contents @@ -260,12 +260,6 @@ class ShapeRefiner { Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, ExtendedInferenceContext* ec); - // Destructive operation, which steals ownership of inference contexts map. - std::unordered_map> - StealInferenceContexts() { - return std::move(node_to_context_); - } - int32 graph_def_version_; const OpRegistryInterface* const ops_registry_; @@ -299,6 +293,11 @@ class ShapeRefiner { // defined functions. By default that info is discarded to save memory. bool keep_nested_shape_inferences_ = false; + // Cache the graph corresponding to each functin definition for which shapes + // are refined. + std::unordered_map> + functions_; + TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner); }; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 391ffda25c0..60d58af61da 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -208,8 +208,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, } // Give the device an opportunity to rewrite its subgraph. - TF_RETURN_IF_ERROR( - unit->device->MaybeRewriteGraph(gdef.library(), &subgraph)); + TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph)); // Top-level nodes in the graph uses the op segment to cache // kernels. Therefore, as long as the executor is alive, we need diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 5798ad09e81..91a1fa7d1e1 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -1044,6 +1044,7 @@ Status MasterSession::Create(GraphDef* graph_def, graph_def, execution_options, &execution_state_)); } if (options.cluster_def != nullptr) { + should_delete_worker_sessions_ = true; return CreateWorkerSessions(options); } return Status::OK(); @@ -1122,6 +1123,59 @@ Status MasterSession::CreateWorkerSessions( return status; } +Status MasterSession::DeleteWorkerSessions() { + WorkerCacheInterface* worker_cache = get_worker_cache(); + std::vector worker_names; + worker_cache->ListWorkers(&worker_names); + + struct WorkerGroup { + // The worker name. (Not owned.) + const string* name; + + // The worker referenced by name. (Not owned.) + WorkerInterface* worker = nullptr; + + // Request and responses used for a given worker. + DeleteWorkerSessionRequest request; + DeleteWorkerSessionResponse response; + Status status = Status::OK(); + }; + BlockingCounter done(worker_names.size()); + std::vector workers(worker_names.size()); + + // Release the workers. + auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] { + for (auto&& worker_group : workers) { + if (worker_group.worker != nullptr) { + worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker); + } + } + }); + + Status status = Status::OK(); + // Create all the workers & kick off the computations. + for (size_t i = 0; i < worker_names.size(); ++i) { + workers[i].name = &worker_names[i]; + workers[i].worker = worker_cache_->CreateWorker(worker_names[i]); + workers[i].request.set_session_handle(handle_); + } + + for (size_t i = 0; i < worker_names.size(); ++i) { + auto cb = [i, &workers, &done](const Status& s) { + workers[i].status = s; + done.DecrementCount(); + }; + workers[i].worker->DeleteWorkerSessionAsync(&workers[i].request, + &workers[i].response, cb); + } + + done.Wait(); + for (size_t i = 0; i < workers.size(); ++i) { + status.Update(workers[i].status); + } + return status; +} + Status MasterSession::ListDevices(ListDevicesResponse* resp) const { if (worker_cache_) { // This is a ClusterSpec-propagated session, and thus env_->local_devices @@ -1604,6 +1658,12 @@ Status MasterSession::Close() { ClearRunsTable(&to_unref, &partial_run_graphs_); } for (ReffedClientGraph* rcg : to_unref) rcg->Unref(); + if (should_delete_worker_sessions_) { + Status s = DeleteWorkerSessions(); + if (!s.ok()) { + LOG(WARNING) << s; + } + } return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index eb696eb06a9..4bd4e1367aa 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -201,6 +201,10 @@ class MasterSession : public core::RefCounted { // workers. Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def); + // TODO(b/36574172): Always use Create/DeleteWorkerSession. + bool should_delete_worker_sessions_ = false; + Status DeleteWorkerSessions(); + Status StartStep(const BuildGraphOptions& opts, int64* count, ReffedClientGraph** graph, bool is_partial); void ClearRunsTable(std::vector* to_unref, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index 2b9798d413c..b3b05408b15 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -39,16 +39,15 @@ namespace tensorflow { class GrpcRemoteWorker : public WorkerInterface { public: - explicit GrpcRemoteWorker(GrpcCounter* live_rpc_counter, - SharedGrpcChannelPtr channel, + explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, WorkerCacheLogger* logger) - : counter_(live_rpc_counter), - channel_(std::move(channel)), + : channel_(std::move(channel)), stub_(channel_), cq_(completion_queue), getstatus_(Method(GrpcWorkerMethod::kGetStatus)), createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)), + deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)), registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)), deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)), rungraph_(Method(GrpcWorkerMethod::kRunGraph)), @@ -73,6 +72,12 @@ class GrpcRemoteWorker : public WorkerInterface { IssueRequest(request, response, createworkersession_, std::move(done)); } + void DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, + StatusCallback done) override { + IssueRequest(request, response, deleteworkersession_, std::move(done)); + } + void RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) override { @@ -182,27 +187,26 @@ class GrpcRemoteWorker : public WorkerInterface { void IssueRequest(const protobuf::Message* request, protobuf::Message* response, const ::grpc::string& method, StatusCallback done, CallOptions* call_opts = nullptr) { - new RPCState(counter_, &stub_, cq_, method, *request, - response, std::move(done), call_opts); + new RPCState(&stub_, cq_, method, *request, response, + std::move(done), call_opts); } void IssueRequest(const protobuf::Message* request, TensorResponse* response, const ::grpc::string& method, StatusCallback done, CallOptions* call_opts = nullptr) { - new RPCState(counter_, &stub_, cq_, method, *request, - response, std::move(done), call_opts); + new RPCState(&stub_, cq_, method, *request, response, + std::move(done), call_opts); } // Helper function for initializing the RpcMethod objects below. const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); } - GrpcCounter* const counter_; SharedGrpcChannelPtr channel_; ::grpc::GenericStub stub_; - ::grpc::CompletionQueue* cq_; const ::grpc::string getstatus_; const ::grpc::string createworkersession_; + const ::grpc::string deleteworkersession_; const ::grpc::string registergraph_; const ::grpc::string deregistergraph_; const ::grpc::string rungraph_; @@ -218,12 +222,10 @@ class GrpcRemoteWorker : public WorkerInterface { TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker); }; -WorkerInterface* NewGrpcRemoteWorker(GrpcCounter* live_rpc_counter, - SharedGrpcChannelPtr channel, +WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, WorkerCacheLogger* logger) { - return new GrpcRemoteWorker(live_rpc_counter, std::move(channel), - completion_queue, logger); + return new GrpcRemoteWorker(std::move(channel), completion_queue, logger); } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h index 174dfcc7072..8ad41335409 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h @@ -26,12 +26,10 @@ class CompletionQueue; namespace tensorflow { -class GrpcCounter; class WorkerCacheLogger; class WorkerInterface; -WorkerInterface* NewGrpcRemoteWorker(GrpcCounter* live_rpc_counter, - SharedGrpcChannelPtr channel, +WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, WorkerCacheLogger* logger); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h index 087b49ba765..3f80bdfb70d 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h @@ -34,24 +34,18 @@ template class RPCState : public GrpcClientCQTag { public: // Default behavior is to set fail_fast = False and handle timeouts manually. - RPCState(GrpcCounter* counter, ::grpc::GenericStub* stub, - ::grpc::CompletionQueue* cq, const ::grpc::string& method, - const protobuf::Message& request, Response* response, - StatusCallback done, CallOptions* call_opts) - : RPCState(counter, stub, cq, method, request, response, std::move(done), + RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, + const ::grpc::string& method, const protobuf::Message& request, + Response* response, StatusCallback done, CallOptions* call_opts) + : RPCState(stub, cq, method, request, response, std::move(done), call_opts, /*fail_fast=*/false, /*timeout_in_ms=*/0) {} template - RPCState(GrpcCounter* counter, ::grpc::GenericStub* stub, - ::grpc::CompletionQueue* cq, const ::grpc::string& method, - const Request& request, Response* response, StatusCallback done, - CallOptions* call_opts, bool fail_fast, int64 timeout_in_ms) - : counter_(counter), call_opts_(call_opts), done_(std::move(done)) { - // TODO(sanjay): The counter will no longer be needed once we - // get a GenericStub API which allows us to manage an entire - // RPC with a single completion event instead of four events. - counter_->Increment(); - + RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, + const ::grpc::string& method, const Request& request, + Response* response, StatusCallback done, CallOptions* call_opts, + bool fail_fast, int64 timeout_in_ms) + : call_opts_(call_opts), done_(std::move(done)) { context_.set_fail_fast(fail_fast); if (timeout_in_ms > 0) { context_.set_deadline(gpr_time_from_millis(timeout_in_ms, GPR_TIMESPAN)); @@ -61,84 +55,43 @@ class RPCState : public GrpcClientCQTag { call_opts->SetCancelCallback([this]() { context_.TryCancel(); }); } - failure_.store(false); - remaining_callbacks_.store(4); // Init/Read/Write/Finish callbacks response_ = response; GrpcMaybeUnparseProto(request, &request_buf_); - // TODO(sanjay): When new enough grpc is available, enable the following: - // context_.set_initial_metadata_corked(true); - // We can then skip the extra state transition for init callback. - call_ = std::move(stub->Call(&context_, method, cq, this)); - call_initialized_.Notify(); + call_ = + std::move(stub->PrepareUnaryCall(&context_, method, request_buf_, cq)); + call_->StartCall(); + call_->Finish(&response_buf_, &status_, this); } - // Called multiple times: when init done, read done, write done, call done. void OnCompleted(bool ok) override { - if (!ok) failure_.store(true); - const int old_count = remaining_callbacks_.fetch_sub(1); - if (old_count > 1) { - if (old_count == 4) { - // Init callback finished. Issue remaining ops. - - // Annoyingly enough, the way the generic call API works is - // inherently racy. We can get the following sequence of events: - // 1. stub->Call() starts. - // 2. some stuff happens inside grpc - // 3. grpc delivers the completion event - // 4. tensorflow event handling thread calls init metadata callback - // 5. stub->Call() finishes - // 6. the result of stub->Call() is stored in call_ - // We are currently inside the callback and therefore need to - // wait for step 6 to finish before attempting to touch call_. - call_initialized_.WaitForNotification(); - - if (ok) { - // TODO(sanjay): Use WriteLast() when grpc version we are using - // is new enough. - call_->Write(request_buf_, this); - call_->Read(&response_buf_, this); - } else { - // Skip Write and Read. - remaining_callbacks_.fetch_sub(2); - } - call_->Finish(&status_, this); - } - // Still waiting for some more callbacks to finish. - return; - } else { // old_count == 1, i.e., all callbacks have finished - // Last callback finished; clean up. - if (call_opts_) { - call_opts_->ClearCancelCallback(); - } - Status s = FromGrpcStatus(status_); - if (s.ok() && failure_.load()) { - s.Update(errors::Internal("callback error")); - } - if (s.ok() && !GrpcMaybeParseProto(response_buf_, response_)) { - s.Update(errors::Internal("could not parse rpc response")); - } - if (!s.ok()) { - VLOG(2) << "Call returned with non-ok status: " << s; - } - done_(s); - counter_->Decrement(); - delete this; + if (call_opts_) { + call_opts_->ClearCancelCallback(); } + Status s = FromGrpcStatus(status_); + if (s.ok() && !ok) { + // Since this function is only being used for processing the response + // to Finish for client-side unary calls, ok should never be false + s.Update(errors::Internal("unexpected ok value at rpc completion")); + } + if (s.ok() && !GrpcMaybeParseProto(response_buf_, response_)) { + s.Update(errors::Internal("could not parse rpc response")); + } + if (!s.ok()) { + VLOG(2) << "Call returned with non-ok status: " << s; + } + done_(s); + delete this; } private: - GrpcCounter* const counter_; CallOptions* call_opts_; ::grpc::ClientContext context_; - std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call_; + std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_; Response* response_; ::grpc::ByteBuffer request_buf_; ::grpc::ByteBuffer response_buf_; ::grpc::Status status_; StatusCallback done_; - std::atomic failure_; - std::atomic remaining_callbacks_; - Notification call_initialized_; }; } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.cc b/tensorflow/core/distributed_runtime/rpc/grpc_util.cc index 9a97978c503..c80728544b0 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_util.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.cc @@ -135,25 +135,4 @@ bool GrpcMaybeParseProto(const grpc::ByteBuffer& src, string* dst) { return true; } -void GrpcCounter::Increment() { - mutex_lock l(mu_); - counter_++; -} - -void GrpcCounter::Decrement() { - mutex_lock l(mu_); - DCHECK_GT(counter_, 0); - counter_--; - if (counter_ == 0) { - empty_.notify_all(); - } -} - -void GrpcCounter::WaitUntilUnused() { - mutex_lock l(mu_); - while (counter_ != 0) { - empty_.wait(l); - } -} - } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.h b/tensorflow/core/distributed_runtime/rpc/grpc_util.h index 04a54e672cb..0ddcd89130b 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_util.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.h @@ -84,29 +84,6 @@ class GrpcByteBufferSource : public ::grpc::protobuf::io::ZeroCopyInputStream { ::grpc::protobuf::int64 byte_count_; }; -// GrpcCounter is used to delay shutdown until all active RPCs are done. -class GrpcCounter { - public: - GrpcCounter() {} - - GrpcCounter(const GrpcCounter&) = delete; - GrpcCounter& operator=(const GrpcCounter&) = delete; - - // Increment the count of live RPCs. - void Increment(); - - // Decrement the count of live RPCs. - void Decrement(); - - // Wait until count of live RPCs is zero. - void WaitUntilUnused(); - - private: - mutex mu_; - condition_variable empty_; - int counter_ = 0; -}; - } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc index 06695db7790..a7b93e04607 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc @@ -51,9 +51,6 @@ class GrpcWorkerCache : public WorkerCachePartial { // Explicit destructor to control destruction order. ~GrpcWorkerCache() override { - // Wait until all live rpcs are done since otherwise the completion - // queue shutdown will interfere with rpc operation. - live_rpc_counter_.WaitUntilUnused(); completion_queue_.Shutdown(); delete polling_thread_; // Blocks until thread exits. delete channel_cache_; @@ -69,8 +66,7 @@ class GrpcWorkerCache : public WorkerCachePartial { } else { SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target); if (!channel) return nullptr; - return NewGrpcRemoteWorker(&live_rpc_counter_, channel, - &completion_queue_, &logger_); + return NewGrpcRemoteWorker(channel, &completion_queue_, &logger_); } } @@ -94,7 +90,6 @@ class GrpcWorkerCache : public WorkerCachePartial { private: const string local_target_; WorkerInterface* const local_worker_; // Not owned. - GrpcCounter live_rpc_counter_; GrpcChannelCache* channel_cache_; // Owned. ::grpc::CompletionQueue completion_queue_; Thread* polling_thread_; // Owned. diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 4ee5ae09017..eee93ec6572 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -114,6 +114,7 @@ class GrpcWorkerService : public AsyncServiceInterface { // types. ENQUEUE_REQUEST(GetStatus, false); ENQUEUE_REQUEST(CreateWorkerSession, false); + ENQUEUE_REQUEST(DeleteWorkerSession, false); ENQUEUE_REQUEST(CleanupAll, false); ENQUEUE_REQUEST(RegisterGraph, false); ENQUEUE_REQUEST(DeregisterGraph, false); @@ -192,6 +193,16 @@ class GrpcWorkerService : public AsyncServiceInterface { ENQUEUE_REQUEST(CreateWorkerSession, false); } + void DeleteWorkerSessionHandler( + WorkerCall* + call) { + Schedule([this, call]() { + Status s = worker_->DeleteWorkerSession(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(DeleteWorkerSession, false); + } + void CleanupAllHandler( WorkerCall* call) { Schedule([this, call]() { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc index 348c6dc98bd..05a9db10d3c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc @@ -32,6 +32,8 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) { return "/tensorflow.WorkerService/GetStatus"; case GrpcWorkerMethod::kCreateWorkerSession: return "/tensorflow.WorkerService/CreateWorkerSession"; + case GrpcWorkerMethod::kDeleteWorkerSession: + return "/tensorflow.WorkerService/DeleteWorkerSession"; case GrpcWorkerMethod::kRegisterGraph: return "/tensorflow.WorkerService/RegisterGraph"; case GrpcWorkerMethod::kDeregisterGraph: diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h index e9862a61a3f..fb23f8631fd 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h @@ -110,6 +110,7 @@ namespace tensorflow { enum class GrpcWorkerMethod { kGetStatus, kCreateWorkerSession, + kDeleteWorkerSession, kRegisterGraph, kDeregisterGraph, kRunGraph, diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index fcb18301970..8bf87923ed4 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -48,6 +48,13 @@ void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, done(s); } +void Worker::DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, + StatusCallback done) { + Status s = env_->session_mgr->DeleteSession(request->session_handle()); + done(s); +} + void Worker::RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) { diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h index 07300338c38..c62347926fa 100644 --- a/tensorflow/core/distributed_runtime/worker.h +++ b/tensorflow/core/distributed_runtime/worker.h @@ -52,6 +52,10 @@ class Worker : public WorkerInterface { CreateWorkerSessionResponse* response, StatusCallback done) override; + void DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, + StatusCallback done) override; + void RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) override; diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h index c9db28ec67f..4c58bf41a46 100644 --- a/tensorflow/core/distributed_runtime/worker_interface.h +++ b/tensorflow/core/distributed_runtime/worker_interface.h @@ -44,6 +44,10 @@ class WorkerInterface { const CreateWorkerSessionRequest* request, CreateWorkerSessionResponse* response, StatusCallback done) = 0; + virtual void DeleteWorkerSessionAsync( + const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, StatusCallback done) = 0; + virtual void RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) = 0; @@ -118,6 +122,11 @@ class WorkerInterface { return CallAndWait(&ME::CreateWorkerSessionAsync, request, response); } + Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response) { + return CallAndWait(&ME::DeleteWorkerSessionAsync, request, response); + } + Status RegisterGraph(const RegisterGraphRequest* request, RegisterGraphResponse* response) { return CallAndWait(&ME::RegisterGraphAsync, request, response); diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index f039497f13b..477184022df 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -243,6 +243,10 @@ DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t; DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;); #undef DEFINE_GET_ATTR +bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) { + return node_def.attr().find(attr_name.ToString()) != node_def.attr().end(); +} + static const string& kEmptyString = *new string(); const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) { diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 523b5382954..f6f28aac481 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -157,6 +157,9 @@ class AttrSlice { const AttrValueMap* attrs_; }; +// Return true if the attr with the name attr_name is defined in node_def. +bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name); + // Look up the attr with name attr_name and set *value to its value. If no // attr with attr_name is found in node_def, or the attr does not have // a matching type, a non-ok status will be returned. diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index 2f737a0f169..f7d4166f970 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -161,6 +161,15 @@ OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) { return nullptr; } +const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) { + for (int i = 0; i < op_def.input_arg_size(); ++i) { + if (op_def.input_arg(i).name() == name) { + return &op_def.input_arg(i); + } + } + return nullptr; +} + #define VALIDATE(EXPR, ...) \ do { \ if (!(EXPR)) { \ diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h index c329e4627cc..f9661dceddc 100644 --- a/tensorflow/core/framework/op_def_util.h +++ b/tensorflow/core/framework/op_def_util.h @@ -43,6 +43,10 @@ Status ValidateAttrValue(const AttrValue& attr_value, const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def); OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def); +// Searches op_def for input argument with the indicated name. +// Returns nullptr if no such attr is found. +const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def); + // Produce a human-readable version of an op_def that is more concise // than a text-format proto. Excludes descriptions. string SummarizeOpDef(const OpDef& op_def); diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index c23692409c6..4d410809e77 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -192,6 +192,10 @@ OpKernelConstruction::OpKernelConstruction( graph_def_version_(graph_def_version), status_(status) {} +bool OpKernelConstruction::HasAttr(StringPiece attr_name) const { + return HasNodeAttr(def(), attr_name); +} + void OpKernelConstruction::SetStatus(const Status& status) { status_->Update(status); } diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 7eec84e26c7..da0dc549435 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -301,6 +301,9 @@ class OpKernelConstruction { template Status GetAttr(StringPiece attr_name, T* value) const; + // Return true if the attr_name is defined in def(). + bool HasAttr(StringPiece attr_name) const; + // Return the device type. const DeviceType& device_type() const { return device_type_; } diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h index 4d9fe1dee97..b389cd80531 100644 --- a/tensorflow/core/graph/graph_def_builder.h +++ b/tensorflow/core/graph/graph_def_builder.h @@ -165,6 +165,20 @@ class GraphDefBuilder { // by name), and makes sure the resulting graph is valid. Status ToGraph(Graph* graph) const; + // Adds the function and gradient definitions in `fdef_lib` to this graph's op + // registry. Ignores duplicate functions, and returns a bad status if an + // imported function differs from an existing function or op with the same + // name. + Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { + return graph_.AddFunctionLibrary(fdef_lib); + } + + // Returns whether a user-defined function with `name` already exists in the + // graph. + bool HasFunction(const string& name) { + return graph_.flib_def().Find(name) != nullptr; + } + private: Graph graph_; Status status_; diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc index ead44de1e2f..e2db47b758f 100644 --- a/tensorflow/core/grappler/clusters/cluster.cc +++ b/tensorflow/core/grappler/clusters/cluster.cc @@ -57,7 +57,7 @@ void Cluster::DisableOptimizer(bool disable) { // Disable Grappler optimizations. auto rewriter_config = options_.config.mutable_graph_options()->mutable_rewrite_options(); - rewriter_config->set_optimize_tensor_layout(false); + rewriter_config->set_layout_optimizer(RewriterConfig::OFF); rewriter_config->set_disable_model_pruning(true); rewriter_config->set_constant_folding(RewriterConfig::OFF); rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT); diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 576745b21a5..548a2c6f70d 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -91,8 +91,15 @@ struct Processor { *result = -counter; counter++; } else { - CHECK_LE(0, InferenceContext::Value(d)); - *result = InferenceContext::Value(d); + int64 val = InferenceContext::Value(d); + if (val >= 0) { + *result = val; + } else { + // A shape inference function generated an invalid dimension handle. + // Use a symbolic dimension to encode this. + *result = -counter; + counter++; + } } } diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 2ab3a9144c8..0bb98d37930 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -677,10 +677,10 @@ Costs VirtualScheduler::Summary() const { critical_path_costs.estimated_max_memory_per_device[name] = max_memory_usage; + const Costs::NanoSeconds wall_time_ns = state.GetCurrTime(); VLOG(1) << "Device = " << name << ", num_nodes = " << state.nodes_executed.size() - << ", execution_time = " << state.GetCurrTime().count() - << ", memory usage: " + << ", wall_time_ns = " << wall_time_ns.count() << ", memory usage: " << "persistent = " << strings::HumanReadableNumBytes(persistent_memory_usage) << ", peak = " @@ -698,9 +698,11 @@ Costs VirtualScheduler::Summary() const { op_to_memory[node->op()] += CalculateOutputSize(node_map_.at(node).output_properties, port); } + Costs::NanoSeconds total_compute_time_ns; for (const auto& op_cost_pair : state.op_to_cost) { const auto& op = op_cost_pair.first; const auto& cost = op_cost_pair.second.execution_time.count(); + total_compute_time_ns += op_cost_pair.second.execution_time; int64 op_mem_usage = 0; auto it = op_to_memory.find(op); if (it != op_to_memory.end()) { @@ -718,6 +720,15 @@ Costs VirtualScheduler::Summary() const { << (persisent_ops.count(op) > 0 ? ": persistent op)" : ")"); } } + + int utilization = 0; + if (wall_time_ns.count() > 0) { + utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count(); + } + VLOG(1) << "Device = " << name + << ", total_compute_time_ns = " << total_compute_time_ns.count() + << ", utilization = " << utilization << "%"; + if (critical_path_costs.execution_time <= state.GetCurrTime()) { critical_path_costs = state.device_costs; } diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 54004a5e07f..dbfa8ae503f 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -194,6 +194,47 @@ tf_cc_test( ], ) +cc_library( + name = "dependency_optimizer", + srcs = ["dependency_optimizer.cc"], + hdrs = [ + "dependency_optimizer.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":arithmetic_optimizer", + ":constant_folding", + ":graph_optimizer", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/utils:frame", + ], +) + +tf_cc_test( + name = "dependency_optimizer_test", + size = "small", + srcs = ["dependency_optimizer_test.cc"], + deps = [ + ":constant_folding", + ":dependency_optimizer", + ":model_pruner", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) + cc_library( name = "model_pruner", srcs = ["model_pruner.cc"], @@ -311,6 +352,7 @@ cc_library( ":arithmetic_optimizer", ":auto_parallel", ":constant_folding", + ":dependency_optimizer", ":graph_optimizer", ":layout_optimizer", ":memory_optimizer", diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index c0518736fe5..c014f8898a9 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -192,7 +192,7 @@ bool SimplyReordersData(const NodeDef& node) { // Follow a chain (through input(0)) of ops starting at `source->input(0)` as // long as they // 1. preserve the values of their first input, -// 2. have a single output, +// 2. have a single (non-control) output, // 3. are not in nodes_to_preserve. // Returns the last node in the chain satisfying these properties or source // itself if a chain of length zero was found. @@ -204,20 +204,55 @@ NodeDef* GetTailOfValuePreservingChain( const NodeDef* source, const NodeMap* node_map, const std::unordered_set& nodes_to_preserve) { const NodeDef* source_parent = source; - source = node_map->GetNode(source->input(0)); - while (IsValuePreserving(*source) && - node_map->GetOutputs(source->name()).size() == 1 && - // Do not skip over preserved nodes, because folding will change - // the results of these skipped data-reordering nodes. - // TODO(jingyue): A more elegant way is to copy this chain of - // data-reordering nodes and modify only the copy. - !nodes_to_preserve.count(source->name())) { - source_parent = source; + if (!IsControlInput(source->input(0))) { source = node_map->GetNode(source->input(0)); + while (IsValuePreserving(*source) && + node_map->GetOutputs(source->name()).size() == 1 && + // Do not skip over preserved nodes, because folding will change + // the results of these skipped data-reordering nodes. + // TODO(jingyue): A more elegant way is to copy this chain of + // data-reordering nodes and modify only the copy. + !nodes_to_preserve.count(source->name())) { + source_parent = source; + if (IsControlInput(source->input(0))) { + break; + } + source = node_map->GetNode(source->input(0)); + } } return const_cast(source_parent); } +bool MaybeAddControlInput(const string& new_input, NodeDef* node, + GraphDef* graph, NodeMap* node_map) { + bool already_exists = false; + for (const string& input : node->input()) { + if (input == new_input || AsControlDependency(input) == new_input) { + already_exists = true; + break; + } + } + if (!already_exists) { + const string ctrl_dep = + ConstantFolding::AddControlDependency(new_input, graph, node_map); + node->add_input(ctrl_dep); + node_map->AddOutput(NodeName(new_input), node->name()); + } + return !already_exists; +} + +int CopyControlInputs(const NodeDef& from, NodeDef* to, GraphDef* graph, + NodeMap* node_map) { + int num_copied = 0; + for (const string& input : from.input()) { + if (IsControlInput(input) && + MaybeAddControlInput(input, to, graph, node_map)) { + ++num_copied; + } + } + return num_copied; +} + // Returns the data type in attribute `attr_name` of `node`. If that attribute // doesn't exist, returns DT_INVALID. DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) { @@ -481,8 +516,10 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const { return true; } -bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const { - if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { +// static +bool ArithmeticOptimizer::CanDedup( + const NodeDef& node, const std::unordered_set& nodes_to_preserve) { + if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) { return false; } if (IsEnter(node) || IsExit(node) || IsPlaceholder(node)) { @@ -520,7 +557,7 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const { continue; } NodeDef* node = optimized_graph->mutable_node(i); - if (!CanDedup(*node)) { + if (!CanDedup(*node, nodes_to_preserve_)) { continue; } NodeDef* rep = nodes.FindOrAddRepresentative(node); @@ -852,7 +889,12 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( // Mul(Const(N), x)) // bool all_equal = true; + int num_inputs = 1; for (int i = 1; i < node->input_size(); ++i) { + if (IsControlInput(node->input(i))) { + break; + } + ++num_inputs; if (node->input(i) != node->input(0)) { all_equal = false; break; @@ -860,10 +902,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } if (all_equal && node_map->GetNode(node->name() + "_const") == nullptr) { // 1. Create constant node with value N. - const int N = node->input_size(); const auto type = GetDataTypeFromAttr(*node, "T"); Tensor t(type, TensorShape({})); - Status status = SetTensorValue(type, N, &t); + Status status = SetTensorValue(type, num_inputs, &t); if (!status.ok()) { LOG(WARNING) << "Failed to create const node: " << status.error_message(); @@ -889,6 +930,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( new_mul_node->add_input(node->input(0)); node_map->AddOutput(node->input(0), new_mul_node->name()); + CopyControlInputs(*node, new_mul_node, graph_def, node_map); AddFrameControlDeps(node, {new_const_node, new_mul_node}, node->input(0), {new_const_node}, graph_def, node_map, frame_map); return new_mul_node->name(); @@ -900,11 +942,12 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( // where all the inputs are Mul nodes. This pattern occurs frequently in // regularization terms for the gradients during training. if (node->input_size() > 1 && IsAggregate(*node) && - node_map->GetNode(node->name() + "_hoist") == nullptr) { + node_map->GetNode(node->name() + "_hoist_add") == nullptr) { // Determine the set of common factors if the input nodes are all Mul nodes. std::set common_factors; int i = 0; - while (i < node->input_size() && (i == 0 || !common_factors.empty())) { + while (i < node->input_size() && (i == 0 || !common_factors.empty()) && + !IsControlInput(node->input(i))) { const NodeDef* input = node_map->GetNode(node->input(i)); if (input->op() == "Mul") { std::set factors_i{input->input(0), input->input(1)}; @@ -934,31 +977,34 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( NodeDef* new_mul_node = graph_def->add_node(); NodeDef* new_add_node = graph_def->add_node(); *new_add_node = *node; - new_add_node->set_name(node->name() + "_hoist"); + new_add_node->set_name(node->name() + "_hoist_add"); new_nodes->push_back(new_add_node); node_map->AddNode(new_add_node->name(), new_add_node); for (int i = 0; i < node->input_size(); ++i) { - NodeDef* mul_node = node_map->GetNode(node->input(i)); + const string& input = node->input(i); + if (IsControlInput(input)) { + MaybeAddControlInput(input, new_add_node, graph_def, node_map); + continue; + } + NodeDef* mul_node = node_map->GetNode(input); int unique_factor_index = mul_node->input(0) == common_factor ? 1 : 0; const string unique_factor = mul_node->input(unique_factor_index); new_add_node->set_input(i, unique_factor); // 2. Use a copy of the first Mul node for the outer multiplication. if (i == 0) { *new_mul_node = *mul_node; - new_mul_node->set_name(new_mul_node->name() + "_hoist"); + new_mul_node->set_device(node->device()); + new_mul_node->set_name(node->name() + "_hoist_mul"); new_mul_node->set_input(0, common_factor); new_mul_node->set_input(1, new_add_node->name()); node_map->AddNode(new_mul_node->name(), new_mul_node); } } - // 3. Set the device of the new nodes to that of the common factor "x". - NodeDef* common_factor_node = node_map->GetNode(common_factor); - new_add_node->set_device(common_factor_node->device()); - new_mul_node->set_device(common_factor_node->device()); - // 4. Add frame dependencies that the original node might have had. + // 3. Add frame dependencies that the original node might have had. AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor, {new_add_node}, graph_def, node_map, frame_map); + return new_mul_node->name(); } } @@ -1121,15 +1167,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps( << consumer->name() << " to " << simplified_tensor; } node_map.UpdateInput(consumer->name(), node->name(), simplified_tensor); - if (!nodes_to_simplify.Exists(consumer)) { - nodes_to_simplify.PushBack(consumer); - } + nodes_to_simplify.PushBack(consumer); } } for (const NodeDef* new_node : new_nodes) { - if (!nodes_to_simplify.Exists(new_node)) { - nodes_to_simplify.PushBack(new_node); - } + nodes_to_simplify.PushBack(new_node); } } return Status::OK(); @@ -1140,7 +1182,6 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, GraphDef* optimized_graph) { *optimized_graph = item.graph; nodes_to_preserve_ = item.NodesToPreserve(); - GraphProperties graph_properties(item); TF_RETURN_IF_ERROR(graph_properties.InferStatically()); TF_RETURN_IF_ERROR(graph_properties.AnnotateOutputShapes(optimized_graph)); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 4d2e160ff48..c8cc292295c 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -28,6 +28,11 @@ namespace grappler { // run a model. class ArithmeticOptimizer : public GraphOptimizer { public: + // Returns true if it is safe to dedup node from the graph. + // TODO(rmlarsen): Refactor to op_types.{h,cc}. + static bool CanDedup(const NodeDef& node, + const std::unordered_set& nodes_to_preserve); + ArithmeticOptimizer() : opt_level_(RewriterConfig::ON) {} explicit ArithmeticOptimizer(RewriterConfig::Toggle opt_level) : opt_level_(opt_level) {} @@ -42,7 +47,6 @@ class ArithmeticOptimizer : public GraphOptimizer { const GraphDef& optimized_graph, double result) override; private: - bool CanDedup(const NodeDef& node) const; void DedupComputations(GraphDef* optimized_graph) const; // Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse // transposes. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 4fcbb0120e6..354a3069052 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -164,6 +164,37 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) { EXPECT_EQ("c", output.node(2).input(0)); } +TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); + Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), c); + Output id1 = ops::Identity(s.WithOpName("id1"), recip1); + Output squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1); + Output recip2 = ops::Reciprocal( + s.WithOpName("recip2").WithControlDependencies(squeeze), c); + Output id2 = ops::Identity(s.WithOpName("id2"), recip2); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ArithmeticOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // The optimizer should be a noop. + EXPECT_EQ(item.graph.node_size(), output.node_size()); + for (int i = 0; i < item.graph.node_size(); ++i) { + const NodeDef& original = item.graph.node(i); + const NodeDef& optimized = output.node(i); + EXPECT_EQ(original.name(), optimized.name()); + EXPECT_EQ(original.op(), optimized.op()); + EXPECT_EQ(original.input_size(), optimized.input_size()); + for (int j = 0; j < original.input_size(); ++j) { + EXPECT_EQ(original.input(j), optimized.input(j)); + } + } +} + TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); @@ -185,6 +216,9 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { EXPECT_EQ(5, output.node_size()); const NodeDef& new_const = output.node(3); EXPECT_EQ("add_const", new_const.name()); + EXPECT_EQ("^x", new_const.input(0)); + EXPECT_EQ(std::string("\0\0\0@", 4), + new_const.attr().at("value").tensor().tensor_content()); const NodeDef& new_mul = output.node(4); EXPECT_EQ("add_mul", new_mul.name()); EXPECT_EQ("add_const", new_mul.input(0)); @@ -194,6 +228,41 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { EXPECT_EQ("add_mul", new_id.input(0)); } +TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2}); + Output x = ops::Const(s.WithOpName("x"), {3.0f, 4.0f}, {1, 2}); + Output add = ops::Add(s.WithOpName("add").WithControlDependencies(y), x, x); + Output id = ops::Identity(s.WithOpName("id"), add); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ArithmeticOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(6, output.node_size()); + const NodeDef& new_const = output.node(4); + EXPECT_EQ("add_const", new_const.name()); + EXPECT_EQ("^x", new_const.input(0)); + EXPECT_EQ(std::string("\0\0\0@", 4), + new_const.attr().at("value").tensor().tensor_content()); + const NodeDef& new_mul = output.node(5); + EXPECT_EQ("add_mul", new_mul.name()); + EXPECT_EQ("add_const", new_mul.input(0)); + EXPECT_EQ("x", new_mul.input(1)); + EXPECT_EQ("^y", new_mul.input(2)); + const NodeDef& new_id = output.node(3); + EXPECT_EQ("id", new_id.name()); + EXPECT_EQ("add_mul", new_id.input(0)); +} + TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { // Test case from b/69059093. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -207,6 +276,13 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + const std::vector devices{ + "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1", + "/device:CPU:0", "/device:CPU:0", "/device:CPU:0", + }; + for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device(devices[i]); + } ArithmeticOptimizer optimizer; GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); @@ -216,36 +292,48 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(11, output.node_size()); - const NodeDef& new_id = output.node(4); - EXPECT_EQ("id", new_id.name()); - EXPECT_EQ("Add_6_mul", new_id.input(0)); - - // Add4 and add5 get deduped, and we rewrite each of the 3 remaining add nodes - // of the form Add(x,x) into Mul(Const(2), x). - const NodeDef& new_add_4_const = output.node(5); - EXPECT_EQ("Add_4_const", new_add_4_const.name()); - EXPECT_EQ("^Add", new_add_4_const.input(0)); - const NodeDef& new_add_4_mul = output.node(6); - EXPECT_EQ("Add_4_mul", new_add_4_mul.name()); - EXPECT_EQ("Add_4_const", new_add_4_mul.input(0)); - EXPECT_EQ("Add_mul", new_add_4_mul.input(1)); - - const NodeDef& new_add_6_const = output.node(7); - EXPECT_EQ("Add_6_const", new_add_6_const.name()); - EXPECT_EQ("^Add_4_mul", new_add_6_const.input(0)); - const NodeDef& new_add_6_mul = output.node(8); - EXPECT_EQ("Add_6_mul", new_add_6_mul.name()); - EXPECT_EQ("Add_6_const", new_add_6_mul.input(0)); - EXPECT_EQ("Add_4_mul", new_add_6_mul.input(1)); - - const NodeDef& new_add_const = output.node(9); - EXPECT_EQ("Add_const", new_add_const.name()); - EXPECT_EQ("^Placeholder", new_add_const.input(0)); - const NodeDef& new_add_mul = output.node(10); - EXPECT_EQ("Add_mul", new_add_mul.name()); - EXPECT_EQ("Add_const", new_add_mul.input(0)); - EXPECT_EQ("Placeholder", new_add_mul.input(1)); + EXPECT_EQ(17, output.node_size()); + // The graph gets optimized to + // Mul(p, + // Add(Add(Const(2), Const(2)), + // Add(Const(2), Const(2)))) + for (const auto& node : output.node()) { + if ("id" == node.name()) { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("Add_6_hoist_mul", node.input(0)); + } else if ("Add_6_hoist_mul" == node.name()) { + EXPECT_EQ("Mul", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("Placeholder", node.input(0)); + EXPECT_EQ("Add_6_hoist_add", node.input(1)); + } else if ("Add_6_hoist_add" == node.name()) { + EXPECT_EQ("Add", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("Add_4_hoist_add", node.input(0)); + EXPECT_EQ("Add_5_hoist_add", node.input(1)); + EXPECT_EQ("^Placeholder", node.input(2)); + } else if ("Add_4_hoist_add" == node.name()) { + EXPECT_EQ("Add", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("Add_const", node.input(0)); + EXPECT_EQ("Add_1_const", node.input(1)); + EXPECT_EQ("^Placeholder", node.input(2)); + } else if ("Add_5_hoist_add" == node.name()) { + EXPECT_EQ("Add", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("Add_const", node.input(0)); + EXPECT_EQ("Add_1_const", node.input(1)); + EXPECT_EQ("^Placeholder", node.input(2)); + } else if ("Add_const" == node.name()) { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("^Placeholder", node.input(0)); + } else if ("Add_1_const" == node.name()) { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("^Placeholder", node.input(0)); + } + } } TEST_F(ArithmeticOptimizerTest, HoistFactor) { @@ -272,16 +360,16 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) { EXPECT_EQ(9, output.node_size()); const NodeDef& new_add = output.node(8); - EXPECT_EQ("add_hoist", new_add.name()); + EXPECT_EQ("add_hoist_add", new_add.name()); EXPECT_EQ("y1", new_add.input(0)); EXPECT_EQ("y2", new_add.input(1)); const NodeDef& new_mul = output.node(7); - EXPECT_EQ("mul1_hoist", new_mul.name()); + EXPECT_EQ("add_hoist_mul", new_mul.name()); EXPECT_EQ("x", new_mul.input(0)); - EXPECT_EQ("add_hoist", new_mul.input(1)); + EXPECT_EQ("add_hoist_add", new_mul.input(1)); const NodeDef& new_id = output.node(6); EXPECT_EQ("id", new_id.name()); - EXPECT_EQ("mul1_hoist", new_id.input(0)); + EXPECT_EQ("add_hoist_mul", new_id.input(0)); } TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) { @@ -463,10 +551,6 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { item.graph = output; TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); - for (const auto& node : output.node()) { - LOG(INFO) << node.DebugString(); - } - EXPECT_EQ(0, std::count_if( output.node().begin(), output.node().end(), [](const NodeDef& node) { return node.op() == "Reshape"; })); @@ -492,10 +576,6 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { item.graph = output; TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); - for (const auto& node : output.node()) { - LOG(INFO) << node.DebugString(); - } - EXPECT_EQ(1, std::count_if( output.node().begin(), output.node().end(), [](const NodeDef& node) { return node.op() == "Reshape"; })); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 02a732b0923..b722905032a 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -122,7 +122,6 @@ string ConstantFolding::AddControlDependency(const string& input_name, auto outputs = node_map->GetOutputs(node->name()); for (const NodeDef* node : outputs) { if (IsIdentity(*node)) { - CHECK_EQ(1, node->input_size()); if (IsSameInput(node->input(0), input_name)) { return AsControlDependency(*node); } @@ -340,92 +339,180 @@ bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties, } } // namespace +Status ConstantFolding::MaterializeBroadcastGradientArgs( + const NodeDef& node, const GraphProperties& properties) { + const NodeDef* shape_node1 = node_map_->GetNode(node.input(0)); + const NodeDef* shape_node2 = node_map_->GetNode(node.input(1)); + if (shape_node1 == nullptr || + (shape_node1->op() != "Shape" && shape_node1->op() != "Const") || + shape_node2 == nullptr || + (shape_node2->op() != "Shape" && shape_node2->op() != "Const")) { + return Status::OK(); + } + int64 min_id = 0; + BCast::Vec shape1; + if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) { + return Status::OK(); + } + BCast::Vec shape2; + if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) { + return Status::OK(); + } + // A value of -1 means we don't known anything about the dimension. Replace + // the -1 values with unique dimension ids since we don't want two '-1' + // dimensions to be considered equal. + for (auto& id : shape1) { + if (id == -1) { + id = --min_id; + } + } + for (auto& id : shape2) { + if (id == -1) { + id = --min_id; + } + } + BCast bcast(shape1, shape2); + if (!bcast.IsValid()) { + return Status::OK(); + } + BCast::Vec reduce_dims[2]; + reduce_dims[0] = bcast.grad_x_reduce_idx(); + reduce_dims[1] = bcast.grad_y_reduce_idx(); + + const DataType type = node.attr().at("T").type(); + NodeDef* out[2]; + for (int j = 0; j < 2; ++j) { + if (!reduce_dims[j].empty()) { + // This is the case when a tensor dimension of 1 is matched against an + // unknown dimension. The unknown dimension could also be equal to 1, in + // which case there would be no reduction. + out[j] = nullptr; + } else { + string const_name = AddPrefixToNodeName( + strings::StrCat(node.name(), "-", j), kConstantFoldingConst); + out[j] = node_map_->GetNode(const_name); + if (out[j] == nullptr) { + out[j] = graph_.add_node(); + Tensor value(type, TensorShape({0})); + *out[j] = CreateNodeDef(const_name, TensorValue(&value)); + out[j]->set_device(node.device()); + node_map_->AddNode(const_name, out[j]); + string ctrl_dep = + AddControlDependency(node.name(), &graph_, node_map_.get()); + *out[j]->add_input() = ctrl_dep; + node_map_->AddOutput(NodeName(ctrl_dep), const_name); + } + } + } + + auto outputs = node_map_->GetOutputs(node.name()); + for (const auto& output : outputs) { + for (int k = 0; k < output->input_size(); ++k) { + int port; + string node_name = ParseNodeName(output->input(k), &port); + if (node_name == node.name() && port >= 0 && port < 2 && out[port]) { + *output->mutable_input(k) = out[port]->name(); + node_map_->UpdateInput(output->name(), node_name, out[port]->name()); + } + } + } + + return Status::OK(); +} + +Status ConstantFolding::MaterializeReductionIndices( + NodeDef* node, const GraphProperties& properties) { + if (node->input_size() < 2) { + return Status::OK(); + } + const NodeDef* indices = node_map_->GetNode(node->input(1)); + if (!indices || IsConstant(*indices)) { + // The reduction indices are already constant, there's nothing to do. + return Status::OK(); + } + + const OpInfo::TensorProperties& input_prop = + properties.GetInputProperties(node->name())[0]; + if (input_prop.shape().unknown_rank()) { + // We can't do anything if we don't know the rank of the input. + return Status::OK(); + } + const int rank = input_prop.shape().dim_size(); + if (rank == 0) { + // Unexpected graph, don't try to change it. + return Status::OK(); + } + const OpInfo::TensorProperties& output_prop = + properties.GetOutputProperties(node->name())[0]; + PartialTensorShape output_shape(output_prop.shape()); + if (output_shape.num_elements() != 1) { + bool full_reduction = false; + for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) { + if (!IsReshape(*fanout)) { + continue; + } + const OpInfo::TensorProperties& reshape_prop = + properties.GetOutputProperties(fanout->name())[0]; + PartialTensorShape shape(reshape_prop.shape()); + if (shape.num_elements() != 1) { + return Status::OK(); + } else { + full_reduction = true; + } + } + if (!full_reduction) { + return Status::OK(); + } + } + + const OpInfo::TensorProperties& reduction_prop = + properties.GetInputProperties(node->name())[1]; + DataType dtype = reduction_prop.dtype(); + if (dtype != DT_INT32 && dtype != DT_INT64) { + return Status::OK(); + } + // We know it's a full reduction. We can generate the set of indices to + // reduce. + string const_name = + AddPrefixToNodeName(strings::StrCat(node->name(), "-reduction_indices"), + kConstantFoldingConst); + if (node_map_->GetNode(const_name)) { + return Status::OK(); + } + NodeDef* reduction_indices = graph_.add_node(); + Tensor value(dtype, TensorShape({rank})); + for (int i = 0; i < rank; ++i) { + if (dtype == DT_INT32) { + value.vec()(i) = i; + } else { + value.vec()(i) = i; + } + } + *reduction_indices = CreateNodeDef(const_name, TensorValue(&value)); + reduction_indices->set_device(node->device()); + *reduction_indices->add_input() = + AddControlDependency(node->input(1), &graph_, node_map_.get()); + node_map_->AddNode(const_name, reduction_indices); + + node->set_input(1, reduction_indices->name()); + node_map_->UpdateInput(node->name(), indices->name(), + reduction_indices->name()); + + return Status::OK(); +} + Status ConstantFolding::MaterializeConstants( const GrapplerItem& item, const GraphProperties& properties) { const int node_count = graph_.node_size(); for (int i = 0; i < node_count; ++i) { NodeDef& node = *graph_.mutable_node(i); const string& op = node.op(); - if (op != "BroadcastGradientArgs") { - continue; - } - const NodeDef* shape_node1 = node_map_->GetNode(node.input(0)); - const NodeDef* shape_node2 = node_map_->GetNode(node.input(1)); - if (shape_node1 == nullptr || - (shape_node1->op() != "Shape" && shape_node1->op() != "Const") || - shape_node2 == nullptr || - (shape_node2->op() != "Shape" && shape_node2->op() != "Const")) { - continue; - } - int64 min_id = 0; - BCast::Vec shape1; - if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) { - continue; - } - BCast::Vec shape2; - if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) { - continue; - } - // A value of -1 means we don't known anything about the dimension. Replace - // the -1 values with unique dimension ids since we don't want two '-1' - // dimensions to be considered equal. - for (auto& id : shape1) { - if (id == -1) { - id = --min_id; - } - } - for (auto& id : shape2) { - if (id == -1) { - id = --min_id; - } - } - BCast bcast(shape1, shape2); - if (!bcast.IsValid()) { - continue; - } - BCast::Vec reduce_dims[2]; - reduce_dims[0] = bcast.grad_x_reduce_idx(); - reduce_dims[1] = bcast.grad_y_reduce_idx(); - - const DataType type = node.attr().at("T").type(); - NodeDef* out[2]; - for (int j = 0; j < 2; ++j) { - if (!reduce_dims[j].empty()) { - // This is the case when a tensor dimension 1 is matched against an - // unknown dimension. The unknown dimension could also be equal to 1, in - // which case there would be no reduction. - out[j] = nullptr; - } else { - Tensor value(type, TensorShape({0})); - string const_name = AddPrefixToNodeName( - strings::StrCat(node.name(), "-", j), kConstantFoldingConst); - out[j] = node_map_->GetNode(const_name); - if (!out[j]) { - out[j] = graph_.add_node(); - *out[j] = CreateNodeDef(const_name, TensorValue(&value)); - out[j]->set_device(node.device()); - node_map_->AddNode(const_name, out[j]); - string ctrl_dep = - AddControlDependency(node.name(), &graph_, node_map_.get()); - *out[j]->add_input() = ctrl_dep; - node_map_->AddOutput(NodeName(ctrl_dep), const_name); - } - } - } - - auto outputs = node_map_->GetOutputs(node.name()); - for (const auto& output : outputs) { - for (int k = 0; k < output->input_size(); ++k) { - int port; - string node_name = ParseNodeName(output->input(k), &port); - if (node_name == node.name() && port >= 0 && port < 2 && out[port]) { - *output->mutable_input(k) = out[port]->name(); - node_map_->UpdateInput(output->name(), node_name, out[port]->name()); - } - } + if (op == "BroadcastGradientArgs") { + TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties)); + } else if (IsReduction(node)) { + TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties)); } } - return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index dd988f336cb..f04f413c10a 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -53,6 +53,12 @@ class ConstantFolding : public GraphOptimizer { private: Status MaterializeShapes(const GrapplerItem& item, const GraphProperties& properties); + + Status MaterializeBroadcastGradientArgs(const NodeDef& node, + const GraphProperties& properties); + Status MaterializeReductionIndices(NodeDef* node, + const GraphProperties& properties); + Status MaterializeConstants(const GrapplerItem& item, const GraphProperties& properties); bool IsFoldable(const NodeDef& node) const; diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 43f84b1ddfd..428376c02cc 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -840,7 +840,7 @@ TEST_F(ConstantFoldingTest, Packing) { EXPECT_GT(8000, output.ByteSizeLong()); } -TEST_F(ConstantFoldingTest, ConstantMaterialization) { +TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT, @@ -918,6 +918,45 @@ TEST_F(ConstantFoldingTest, ConstantMaterialization) { EXPECT_EQ(7, found); } +TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); + Output indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); + Output sum = ops::Sum(s.WithOpName("sum"), input, indices); + Output size = ops::Const(s.WithOpName("size"), 1, {1}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ConstantFolding fold(RewriterConfig::AGGRESSIVE, nullptr /* cpu_device */); + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // Run a second time to make sure the optimization is idempotent. + item.graph.Swap(&output); + status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + int found = 0; + for (const auto& node : output.node()) { + if (node.name() == "ConstantFolding/sum-reduction_indices") { + ++found; + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("^indices", node.input(0)); + EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape()) + .num_elements()); + } else if (node.name() == "sum") { + ++found; + EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1)); + } + } + EXPECT_EQ(2, found); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc new file mode 100644 index 00000000000..49eb29d0371 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -0,0 +1,278 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" + +#include + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" +#include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/core/grappler/utils/frame.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace grappler { + +namespace { +// A vector with a set. The set stores the same elements as the vector, and +// quickly answers whether a value is in the vector. Duplicated elements are not +// allowed for now. +template +class SetVector { + public: + // Returns false if value already existed in the set, true otherwise. + bool PushBack(const T& value) { + if (!set_.insert(value).second) { + return false; + } + vector_.push_back(value); + return true; + } + + T PopBack() { + T back = vector_.back(); + set_.erase(back); + vector_.pop_back(); + return back; + } + + bool Exists(const T& value) const { return set_.count(value); } + + bool Empty() const { return vector_.empty(); } + + void Reserve(int64 size) { vector_.reserve(size); } + + private: + std::unordered_set set_; + std::vector vector_; +}; + +bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) { + for (const NodeDef* output : node_map.GetOutputs(node.name())) { + for (const string& input : output->input()) { + if (input == node.name()) { + return true; + } + } + } + return false; +} + +int FindInputSlot(const NodeDef& node, const string& input) { + for (int i = 0; i < node.input_size(); ++i) { + if (node.input(i) == input) { + return i; + } + } + return -1; +} + +} // namespace + +bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) { + if (!has_fetch_ || HasRegularOutputs(node, *node_map_)) { + return false; + } + + if (IsMerge(node)) { + return false; + } + if (!ArithmeticOptimizer::CanDedup(node, nodes_to_preserve_)) { + return false; + } + + const OpDef* op_def = nullptr; + Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + if (!status.ok() || op_def->output_arg_size() == 0) { + return false; + } + + // TODO(rmlarsen): We have to skip Const nodes to make + // core/debug/debug_gateway_test pass. See if we can fix that test. + // TODO(rmlarsen): We have to skip Identity nodes to make an obsolete test in + // python/training/session_manager_test.py pass. See if we can fix or get rid + // of that test. + const std::unordered_set do_not_rewrite_ops = { + "Assert", "CheckNumerics", "Const", "Identity", "_Retval", + "_Arg", "_ParallelConcatUpdate", "_TPUExecute"}; + return do_not_rewrite_ops.find(node.op()) == do_not_rewrite_ops.end(); +} + +string DependencyOptimizer::TryOptimizeDependencies( + NodeDef* node, GraphDef* graph, std::vector* new_nodes) { + // Change ops that only have control dependencies as outputs to NoOps. + if (node->op() != "NoOp" && SafeToConvertToNoOp(*node)) { + VLOG(2) << "***** Replacing " << node->name() << " (" << node->op() + << ") with NoOp."; + // The outputs of this node are not consumed. Replace its inputs with + // control dependencies and replace the op itself with the NoOp op. + for (int i = 0; i < node->input_size(); ++i) { + const string& old_input = node->input(i); + if (IsControlInput(old_input)) { + continue; + } + const string ctrl_input = ConstantFolding::AddControlDependency( + old_input, graph, node_map_.get()); + node->set_input(i, ctrl_input); + node_map_->UpdateInput(node->name(), old_input, ctrl_input); + new_nodes->push_back(node_map_->GetNode(old_input)); + } + node->set_op("NoOp"); + node->clear_attr(); + new_nodes->push_back(node); + return ""; + } + + // Remove NoOp nodes if their fan-in or fan-out is less than 2. + // The non-trivial rewrites take the following form: + // + // Case a) + // x --^> +------+ x --^> +---+ + // y --^> | NoOp | --^> a ==> y --^> | a | + // ... | | ... | | + // z --^> +------+ z --^> +---+ + // + // Case b) + // +------+ --^> a +---+ --^> a + // x --^> | NoOp | --^> b ==> | x | --^> b + // | | ... | | ... + // +------+ --^> c +---+ --^> c + if (node->op() == "NoOp" && + nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) { + auto outputs = node_map_->GetOutputs(node->name()); + const int num_outputs = outputs.size(); + const int num_inputs = node->input_size(); + if (num_inputs > 1 && num_outputs > 1) { + return ""; + } + + for (auto consumer : outputs) { + for (int i = 0; i < num_inputs; ++i) { + const string& input = node->input(i); + // Forward dependencies from inputs to consumer if it doesn't already + // depend on it. + if (node_map_->GetOutputs(input).count(consumer) == 0) { + consumer->add_input(ConstantFolding::AddControlDependency( + input, graph, node_map_.get())); + node_map_->AddOutput(NodeName(input), consumer->name()); + } + new_nodes->push_back(node_map_->GetNode(input)); + } + // Remove dependency on node from consumer. + int pos = FindInputSlot(*consumer, AsControlDependency(node->name())); + if (pos >= 0) { + consumer->mutable_input()->SwapElements(pos, + consumer->input_size() - 1); + consumer->mutable_input()->RemoveLast(); + node_map_->RemoveOutput(node->name(), consumer->name()); + new_nodes->push_back(consumer); + } + } + + // Clear all control inputs to node. + node_map_->RemoveInputs(node->name()); + node->clear_input(); + return ""; + } + + return ""; +} + +Status DependencyOptimizer::OptimizeDependencies(GraphDef* optimized_graph) { + // TODO(rmlarsen,bsteiner): The folloing code is similar to the control loop + // in the ArithmeticOptimizer. Dedup this. + SetVector nodes_to_simplify; + for (int i = 0; i < optimized_graph->node_size(); ++i) { + const NodeDef& node = optimized_graph->node(i); + if (node.op() == "NoOp" || SafeToConvertToNoOp(node)) { + nodes_to_simplify.PushBack(optimized_graph->mutable_node()->Mutable(i)); + } + } + while (!nodes_to_simplify.Empty()) { + NodeDef* node = nodes_to_simplify.PopBack(); + std::vector new_nodes; + const string simplified_tensor = + TryOptimizeDependencies(node, optimized_graph, &new_nodes); + if (simplified_tensor.empty()) { + continue; + } + if (NodeName(simplified_tensor) != node->name()) { + // Always consider simplified_tensor for further optimizations. + NodeDef* simplified_node = node_map_->GetNode(simplified_tensor); + if (simplified_node != nullptr) { + nodes_to_simplify.PushBack(simplified_node); + } + // When `node` is simplifed to another node rather than in-place, the + // consumers of `node` are already redirected to `simplified_tensor`. + // Re-push the consumers into `nodes_to_simplify` for further + // optimizations. + std::set consumers = node_map_->GetOutputs(node->name()); + for (NodeDef* consumer : consumers) { + // Update `consumer`'s use of `node` to `input`'s operand. + for (int i = 0; i < consumer->input_size(); ++i) { + int operand_pos; + string operand_node_name = + ParseNodeName(consumer->input(i), &operand_pos); + if (operand_node_name == node->name()) { + *consumer->mutable_input(i) = + (operand_pos < 0 + ? AsControlDependency(NodeName(simplified_tensor)) + : simplified_tensor); + } + VLOG(2) << "Update input " << consumer->input(i) << " of " + << consumer->name() << " to " << simplified_tensor; + } + node_map_->UpdateInput(consumer->name(), node->name(), + simplified_tensor); + nodes_to_simplify.PushBack(consumer); + } + } + for (auto new_node : new_nodes) { + nodes_to_simplify.PushBack(new_node); + } + } + return Status::OK(); +} + +Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + *optimized_graph = item.graph; + nodes_to_preserve_ = item.NodesToPreserve(); + node_map_.reset(new NodeMap(optimized_graph)); + has_fetch_ = !item.fetch.empty(); + VLOG(2) << "Graph before optimization:\n" << optimized_graph->DebugString(); + TF_RETURN_IF_ERROR(OptimizeDependencies(optimized_graph)); + VLOG(2) << "Graph after optimization:\n" << optimized_graph->DebugString(); + + return Status::OK(); +} + +void DependencyOptimizer::Feedback(Cluster* /*cluster*/, + const GrapplerItem& /*item*/, + const GraphDef& /*optimized_graph*/, + double /*result*/) { + // Nothing to do for DependencyOptimizer. +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h new file mode 100644 index 00000000000..13ece87aff3 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ + +#include +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// Optimize TF computations by removing control dependencies or re-arranging +// them to shorten the critical path for a model step or enable other +// optimizations, such as removing nodes that are effectively noops. +class DependencyOptimizer : public GraphOptimizer { + public: + DependencyOptimizer() : opt_level_(RewriterConfig::ON) {} + explicit DependencyOptimizer(RewriterConfig::Toggle opt_level) + : opt_level_(opt_level) {} + ~DependencyOptimizer() override {} + + string name() const override { return "dependency_optimizer"; }; + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override; + + private: + // Returns true if it is safe to convert node to NoOp. + bool SafeToConvertToNoOp(const NodeDef& node); + + Status OptimizeDependencies(GraphDef* optimized_graph); + // Tries to simplify the expression that roots at `node` and replaces the uses + // of `node` to the simplified expression. Returns the name of the simplified + // tensor (e.g. "split:1") or an empty string if no simplification is + // performed. + string TryOptimizeDependencies(NodeDef* node, GraphDef* graph, + std::vector* new_nodes); + + bool HasOnlyControlOutputs(const NodeDef* node); + + bool has_fetch_; + RewriterConfig::Toggle opt_level_; + std::unordered_set nodes_to_preserve_; + std::unique_ptr node_map_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc new file mode 100644 index 00000000000..d54d7b2093e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -0,0 +1,201 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/core/grappler/optimizers/model_pruner.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class DependencyOptimizerTest : public ::testing::Test {}; + +void VerifyGraphsEqual(const GraphDef& original_graph, + const GraphDef& optimized_graph, const string& func) { + EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func; + for (int i = 0; i < original_graph.node_size(); ++i) { + const NodeDef& original = original_graph.node(i); + const NodeDef& optimized = optimized_graph.node(i); + EXPECT_EQ(original.name(), optimized.name()) << func; + EXPECT_EQ(original.op(), optimized.op()) << func; + EXPECT_EQ(original.input_size(), optimized.input_size()) << func; + for (int j = 0; j < original.input_size(); ++j) { + EXPECT_EQ(original.input(j), optimized.input(j)) << func; + } + } +} + +TEST_F(DependencyOptimizerTest, NoOp) { + // This trivial graph is so basic there's nothing to optimize. + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + VerifyGraphsEqual(item.graph, output, __FUNCTION__); +} + +TEST_F(DependencyOptimizerTest, ChangeToNoop) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2}); + Output add = ops::Add(s.WithOpName("add"), x, y); + Output id1 = + ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x); + Output id2 = + ops::Identity(s.WithOpName("id2").WithControlDependencies(add), y); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("id1"); + item.fetch.push_back("id2"); + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(item.graph.node_size(), output.node_size()); + for (int i = 0; i < item.graph.node_size(); ++i) { + const NodeDef& original = item.graph.node(i); + const NodeDef& optimized = output.node(i); + EXPECT_EQ(original.name(), optimized.name()); + if (original.name() == "add") { + EXPECT_EQ("NoOp", optimized.op()); + } else { + EXPECT_EQ(original.op(), optimized.op()); + } + EXPECT_EQ(original.input_size(), optimized.input_size()); + for (int j = 0; j < original.input_size(); ++j) { + if (original.name() == "add") { + EXPECT_EQ(AsControlDependency(original.input(j)), optimized.input(j)); + } else { + EXPECT_EQ(original.input(j), optimized.input(j)); + } + } + } +} + +TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2}); + Output add = ops::Add(s.WithOpName("add"), x, y); + Output id1 = + ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x); + Output id2 = + ops::Identity(s.WithOpName("id2").WithControlDependencies(add), y); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + VerifyGraphsEqual(item.graph, output, __FUNCTION__); +} + +TEST_F(DependencyOptimizerTest, RemoveNoOps_EmptyInputOrOutput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Const(s, {1.0f, 2.0f}, {1, 2}); + auto noop1 = ops::NoOp(s); + auto noop2 = ops::NoOp(s.WithControlDependencies(x)); + Output id = ops::Identity(s.WithControlDependencies({noop1.operation}), x); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("Identity"); + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(item.graph.node_size(), output.node_size()); + for (const NodeDef& node : output.node()) { + if (node.name() == "NoOp" || node.name() == "NoOp_1") { + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "Identity") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("Const", node.input(0)); + } + } +} + +TEST_F(DependencyOptimizerTest, RemoveNoOps_SingleInputOrOutput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2}); + // NoOp with a single input- and two output dependencies. + auto noop = ops::NoOp(s.WithControlDependencies(x)); + // NoOp with a two input- and a single output dependency. + auto noop_1 = + ops::NoOp(s.WithControlDependencies(x).WithControlDependencies(y)); + Output id = ops::Identity(s.WithControlDependencies({noop.operation}), x); + Output id_1 = ops::Identity( + s.WithControlDependencies({noop.operation, noop_1.operation}), y); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("Identity"); + item.fetch.push_back("Identity_1"); + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(item.graph.node_size(), output.node_size()); + for (const NodeDef& node : output.node()) { + if (node.name() == "NoOp" || node.name() == "NoOp_1") { + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "Identity") { + EXPECT_EQ("x", node.input(0)); + } else if (node.name() == "Identity_1") { + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("^x", node.input(1)); + } + } +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 6204a81f805..1fa639ad33d 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" #include "tensorflow/core/grappler/optimizers/auto_parallel.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" #include "tensorflow/core/grappler/optimizers/memory_optimizer.h" @@ -53,6 +54,10 @@ std::unique_ptr MetaOptimizer::NewOptimizer( graph_optimizer.reset( new AutoParallel(cfg_.auto_parallel().num_replicas())); } + if (optimizer == "dependency") { + graph_optimizer.reset( + new DependencyOptimizer(cfg_.dependency_optimization())); + } return graph_optimizer; } @@ -71,7 +76,11 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimizers.push_back(std::unique_ptr( new ArithmeticOptimizer(cfg_.arithmetic_optimization()))); } - if (cfg_.optimize_tensor_layout()) { + if (cfg_.dependency_optimization() == RewriterConfig::ON) { + optimizers.push_back(std::unique_ptr( + new DependencyOptimizer(cfg_.dependency_optimization()))); + } + if (cfg_.layout_optimizer() == RewriterConfig::ON) { optimizers.push_back( std::unique_ptr(new LayoutOptimizer())); } @@ -92,9 +101,9 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, new AutoParallel(cfg_.auto_parallel().num_replicas()))); } } else { - std::set available_optimizers = {"pruning", "constfold", - "layout", "memory", - "autoparallel", "arithmetic"}; + std::set available_optimizers = { + "pruning", "constfold", "layout", "memory", + "autoparallel", "arithmetic", "dependency"}; for (const auto& optimizer : cfg_.optimizers()) { if (available_optimizers.find(optimizer) != available_optimizers.end()) { optimizers.push_back(NewOptimizer(optimizer)); @@ -175,8 +184,10 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, } bool MetaOptimizerEnabled(const RewriterConfig& cfg) { - return !cfg.disable_model_pruning() || cfg.optimize_tensor_layout() || + return !cfg.disable_model_pruning() || + cfg.layout_optimizer() == RewriterConfig::ON || cfg.constant_folding() != RewriterConfig::OFF || + cfg.dependency_optimization() == RewriterConfig::ON || cfg.arithmetic_optimization() != RewriterConfig::OFF || cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 || !cfg.optimizers().empty(); diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 3a5028cfe3a..9452cfbf557 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -221,8 +221,11 @@ string AsControlDependency(const NodeDef& node) { return strings::StrCat("^", node.name()); } -string AsControlDependency(const string& node) { - return strings::StrCat("^", node); +string AsControlDependency(const string& node_name) { + CHECK(!node_name.empty()); + return (!node_name.empty() && node_name[0] == '^') + ? node_name + : strings::StrCat("^", node_name); } int NumOutputs(const NodeDef& node) { diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 3193b3ec4a6..9d747fe7dc4 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -181,6 +181,14 @@ TEST_F(UtilsTest, NumOutputs) { EXPECT_EQ(1, NumOutputs(CreateDequeueNode())); } +TEST(AsControlDependency, BasicTest) { + NodeDef node; + node.set_name("foo"); + EXPECT_EQ("^foo", AsControlDependency(node)); + EXPECT_EQ("^foo", AsControlDependency(node.name())); + EXPECT_EQ("^foo", AsControlDependency("^foo")); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 52ea2ad4806..cf95c6781a4 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -6270,6 +6270,7 @@ tf_kernel_library( "//tensorflow/contrib/tensorboard/db:summary_db_writer", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:summary_ops_op_lib", "//tensorflow/core/lib/db:sqlite", ], diff --git a/tensorflow/core/kernels/batch_dataset_op.cc b/tensorflow/core/kernels/batch_dataset_op.cc index 6a5fd17a9e6..46412a554b3 100644 --- a/tensorflow/core/kernels/batch_dataset_op.cc +++ b/tensorflow/core/kernels/batch_dataset_op.cc @@ -80,10 +80,10 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); Node* batch_size = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/captured_function.h b/tensorflow/core/kernels/captured_function.h index 55d337d7075..9430127600a 100644 --- a/tensorflow/core/kernels/captured_function.h +++ b/tensorflow/core/kernels/captured_function.h @@ -71,6 +71,8 @@ class CapturedFunction { ResourceMgr* resource_manager() const { return device_->resource_manager(); } + const std::vector& captured_inputs() { return captured_inputs_; } + static int64 generate_step_id() { // Choose a step ID that is guaranteed not to clash with any // Session-generated step ID. DirectSession only generates diff --git a/tensorflow/core/kernels/check_numerics_op.cc b/tensorflow/core/kernels/check_numerics_op.cc index 56cb50d2d18..534527c6bdc 100644 --- a/tensorflow/core/kernels/check_numerics_op.cc +++ b/tensorflow/core/kernels/check_numerics_op.cc @@ -168,10 +168,10 @@ class CheckNumericsOp : public AsyncOpKernel { abnormal_detected_host, context, done]() { ::perftools::gputools::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()}; - auto abnormal_detected_host_flat = abnormal_detected_host.flat(); int is_nan = abnormal_detected_host_flat(0); int is_inf = abnormal_detected_host_flat(1); + abnormal_detected_ref.Unref(); if (is_nan || is_inf) { string status; LOG(ERROR) << "abnormal_detected_host @" diff --git a/tensorflow/core/kernels/concatenate_dataset_op.cc b/tensorflow/core/kernels/concatenate_dataset_op.cc index c3bd89c479f..ad78ba01869 100644 --- a/tensorflow/core/kernels/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/concatenate_dataset_op.cc @@ -79,13 +79,13 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { string DebugString() override { return "ConcatenateDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph)); + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph)); Node* to_concatenate_graph = nullptr; TF_RETURN_IF_ERROR( - b->AddParentDataset(to_concatenate_, &to_concatenate_graph)); + b->AddParentDataset(ctx, to_concatenate_, &to_concatenate_graph)); TF_RETURN_IF_ERROR( b->AddDataset(this, {input_graph, to_concatenate_graph}, output)); return Status::OK(); diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h index a90590fc7e0..b9b0e5a7c6c 100644 --- a/tensorflow/core/kernels/dataset.h +++ b/tensorflow/core/kernels/dataset.h @@ -137,6 +137,23 @@ class GraphDefBuilderWrapper { const std::vector& inputs, const std::vector>& attrs, Node** output) { + std::vector> enumerated_inputs( + inputs.size()); + for (int i = 0; i < inputs.size(); i++) { + enumerated_inputs[i] = std::make_pair(i, inputs[i]); + } + return AddDataset(dataset, enumerated_inputs, {}, attrs, output); + } + + template + Status AddDataset( + const DatasetType* dataset, + const std::vector>& inputs, + const std::vector< + std::pair>>& + list_inputs, + const std::vector>& attrs, + Node** output) { const string& op_type_name = dataset->op_name(); std::unique_ptr opts( new GraphDefBuilder::Options(b_->opts())); @@ -161,8 +178,22 @@ class GraphDefBuilderWrapper { } NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name, opts->op_registry()); - for (auto node_out : inputs) { - node_builder.Input(node_out); + { + size_t total_size = inputs.size() + list_inputs.size(); + auto inputs_iter = inputs.begin(); + auto list_inputs_iter = list_inputs.begin(); + for (int i = 0; i < total_size; i++) { + if (inputs_iter != inputs.end() && inputs_iter->first == i) { + node_builder.Input(inputs_iter->second); + inputs_iter++; + } else if (list_inputs_iter != list_inputs.end() && + list_inputs_iter->first == i) { + node_builder.Input(list_inputs_iter->second); + list_inputs_iter++; + } else { + return errors::InvalidArgument("No input found for index ", i); + } + } } *output = opts->FinalizeBuilder(&node_builder); if (*output == nullptr) { @@ -172,35 +203,56 @@ class GraphDefBuilderWrapper { return Status::OK(); } - // TODO(shivaniagrawal): Single method for AddDataset for - // NodeOut/ArrraySlice - template - Status AddDatasetWithInputAsList(const DatasetType* dataset, - gtl::ArraySlice input, - Node** output) { - const string& op_type_name = dataset->op_name(); - std::unique_ptr opts( - new GraphDefBuilder::Options(b_->opts())); - bool has_output_types_attr = HasAttr(op_type_name, "output_types"); - bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes"); - if (has_output_shapes_attr) { - opts.reset(new GraphDefBuilder::Options( - opts->WithAttr("output_shapes", dataset->output_shapes()))); + // Adds a user-defined function with name `function_name` to the graph and + // recursively adds all functions it references. If a function with a matching + // name has already been added, returns with OK status. If a user-defined with + // name `function_name` is not found in the FunctionLibraryDefinition, returns + // and InvalidArgumentError. If the function with name `function_name` or any + // of its dependent functions are stateful, returns an InvalidArgument error. + Status AddFunction(OpKernelContext* ctx, const string& function_name) { + if (b_->HasFunction(function_name)) { + LOG(INFO) << "Function with name " << function_name << "already exists in" + << " the graph. It will not be added again."; + return Status::OK(); } - if (has_output_types_attr) { - opts.reset(new GraphDefBuilder::Options( - opts->WithAttr("output_types", dataset->output_dtypes()))); + TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name)); + const FunctionLibraryDefinition* flib_def = + ctx->function_library()->GetFunctionLibraryDefinition(); + const FunctionDef* f_def = flib_def->Find(function_name); + if (f_def == nullptr) { + return errors::InvalidArgument("Unable to find FunctionDef for ", + function_name, " in the registry."); } - if (opts->HaveError()) { - return errors::Internal("AddDataset: Error building Options."); + FunctionDefLibrary def; + *def.add_function() = *f_def; + const string gradient_func = flib_def->FindGradient(function_name); + if (!gradient_func.empty()) { + GradientDef* g_def = def.add_gradient(); + g_def->set_function_name(function_name); + g_def->set_gradient_func(gradient_func); } - NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name, - opts->op_registry()); - node_builder.Input(input); - *output = opts->FinalizeBuilder(&node_builder); - if (*output == nullptr) { - return errors::Internal("AddDataset: Failed to build ", op_type_name, - " op."); + TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def)); + + // Recursively add functions in inputs of function_name. + for (const NodeDef& node_def : f_def->node_def()) { + const OpRegistrationData* op_reg_data = nullptr; + TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data)); + if (op_reg_data->is_function_op) { + TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name())); + } + } + + // Recursively add functions in attrs of function_name. + for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); + iter++) { + const AttrValue& attr_value = iter->second; + if (attr_value.has_func()) { + TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name())); + } else if (attr_value.has_list()) { + for (const NameAttrList& name_attr_list : attr_value.list().func()) { + TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name())); + } + } } return Status::OK(); } @@ -217,6 +269,28 @@ class GraphDefBuilderWrapper { b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val)); } + Status EnsureFunctionIsStateless(OpKernelContext* ctx, + const string& function_name) const { + const FunctionLibraryDefinition* lib_def = + ctx->function_library()->GetFunctionLibraryDefinition(); + const FunctionDef* function_def = lib_def->Find(function_name); + if (!function_def) { + return errors::InvalidArgument("Unable to find FunctionDef for ", + function_name, " in registry."); + } + for (const NodeDef& node_def : function_def->node_def()) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def)); + if (op_def->is_stateful()) { + return errors::InvalidArgument( + "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ", + "in function ", function_name, " is stateful. ", + "Saving stateful functions is not supported yet."); + } + } + return Status::OK(); + } + bool HasAttr(const string& op_type_name, const string& attr_name) { const OpDef* op_def = nullptr; Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def); @@ -306,7 +380,7 @@ class IteratorBase { virtual const std::vector& output_shapes() const = 0; // Saves the state of this iterator. - virtual Status Save(IteratorStateWriter* writer) { + virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) { return SaveInternal(writer); } @@ -377,7 +451,7 @@ class DatasetBase : public core::RefCounted { virtual string DebugString() = 0; // Serializes the dataset and writes it to the `writer`. - virtual Status Save(IteratorStateWriter* writer) const { + virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) const { return errors::Unimplemented("DatasetBase::Save"); } @@ -389,11 +463,18 @@ class DatasetBase : public core::RefCounted { class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { public: DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} - Status AddParentDataset(const DatasetBase* dataset, Node** output) { - return dataset->AsGraphDefInternal(this, output); + Status AddParentDataset(OpKernelContext* ctx, const DatasetBase* dataset, + Node** output) { + return dataset->AsGraphDefInternal(ctx, this, output); } }; + virtual Status AsGraphDefInternal(OpKernelContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const { + return AsGraphDefInternal(b, node); + } + virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b, Node** node) const { return errors::Unimplemented("AsGraphDefInternal"); @@ -408,10 +489,11 @@ class GraphDatasetBase : public DatasetBase { const string op_name() const { return op_name_; } - Status Save(IteratorStateWriter* writer) const override { + Status Save(OpKernelContext* ctx, + IteratorStateWriter* writer) const override { string serialized_graph_def; string output_node; - TF_RETURN_IF_ERROR(Serialize(&serialized_graph_def, &output_node)); + TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node)); TF_RETURN_IF_ERROR( writer->WriteScalar(kDatasetGraphKey, serialized_graph_def)); TF_RETURN_IF_ERROR( @@ -427,11 +509,12 @@ class GraphDatasetBase : public DatasetBase { static const char kDatasetGraphOutputNodeKey[]; private: - Status Serialize(string* serialized_graph_def, string* output_node) const { + Status Serialize(OpKernelContext* ctx, string* serialized_graph_def, + string* output_node) const { GraphDefBuilder b; DatasetGraphDefBuilder db(&b); Node* node = nullptr; - TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node)); + TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node)); *output_node = node->name(); GraphDef graph_def; TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); @@ -480,9 +563,9 @@ class DatasetIterator : public IteratorBase { return GetNextInternal(ctx, out_tensors, end_of_sequence); } - Status Save(IteratorStateWriter* writer) final { - TF_RETURN_IF_ERROR(dataset()->Save(writer)); - return IteratorBase::Save(writer); + Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final { + TF_RETURN_IF_ERROR(dataset()->Save(ctx, writer)); + return IteratorBase::Save(ctx, writer); } protected: diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc index ae77ae64338..b48da5b3263 100644 --- a/tensorflow/core/kernels/iterator_ops.cc +++ b/tensorflow/core/kernels/iterator_ops.cc @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/dataset.h" - #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/framework/iterator.pb.h" @@ -22,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/kernels/dataset.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -79,10 +78,12 @@ Status VerifyShapesCompatible(const std::vector& expected, class IteratorResource : public ResourceBase { public: IteratorResource(const DataTypeVector& output_dtypes, - const std::vector& output_shapes) + const std::vector& output_shapes, + const int graph_def_version) : iterator_(nullptr), output_dtypes_(output_dtypes), - output_shapes_(output_shapes) {} + output_shapes_(output_shapes), + graph_def_version_(graph_def_version) {} Status GetNext(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) { @@ -97,10 +98,10 @@ class IteratorResource : public ResourceBase { } } - Status Save(IteratorStateWriter* writer) { + Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) { std::shared_ptr captured_iterator(iterator_); if (captured_iterator) { - return captured_iterator->Save(writer); + return captured_iterator->Save(ctx, writer); } else { return errors::FailedPrecondition( "Save() failed because the iterator has not been initialized. " @@ -125,8 +126,21 @@ class IteratorResource : public ResourceBase { TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); std::vector outputs; GraphRunner graph_runner(ctx->env()); - TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {}, - {output_node}, &outputs)); + + // Build a new FLR that knows about the functions in the graph. + std::unique_ptr flib_def( + new FunctionLibraryDefinition( + *ctx->function_library()->GetFunctionLibraryDefinition())); + TF_RETURN_IF_ERROR(flib_def->AddLibrary(graph_def.library())); + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime(nullptr, ctx->env(), + graph_def_version_, flib_def.get(), + {}, nullptr)); + FunctionLibraryRuntime* lib = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + TF_RETURN_IF_ERROR( + graph_runner.Run(&graph, lib, {}, {output_node}, &outputs)); TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset)); TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator"))); @@ -166,6 +180,7 @@ class IteratorResource : public ResourceBase { std::shared_ptr iterator_; const DataTypeVector output_dtypes_; const std::vector output_shapes_; + const int graph_def_version_; }; // Helper class for reading data from a VariantTensorData object. @@ -319,11 +334,12 @@ class IteratorStateVariant { } // Initializes this object with the current state of the iterator so // that it can be written on the next call to Encode(). - Status InitializeFromIterator(IteratorResource* iterator_resource) { + Status InitializeFromIterator(OpKernelContext* ctx, + IteratorResource* iterator_resource) { data_.reset(new VariantTensorData()); data_->set_type_name(TypeName()); VariantTensorDataWriter writer(data_.get()); - TF_RETURN_IF_ERROR(iterator_resource->Save(&writer)); + TF_RETURN_IF_ERROR(iterator_resource->Save(ctx, &writer)); TF_RETURN_IF_ERROR(writer.Flush()); return Status::OK(); } @@ -375,7 +391,8 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant, class IteratorHandleOp : public ResourceOpKernel { public: explicit IteratorHandleOp(OpKernelConstruction* ctx) - : ResourceOpKernel(ctx) { + : ResourceOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } @@ -383,7 +400,8 @@ class IteratorHandleOp : public ResourceOpKernel { private: Status CreateResource(IteratorResource** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = new IteratorResource(output_dtypes_, output_shapes_); + *ret = new IteratorResource(output_dtypes_, output_shapes_, + graph_def_version_); return Status::OK(); } @@ -398,6 +416,7 @@ class IteratorHandleOp : public ResourceOpKernel { private: DataTypeVector output_dtypes_; std::vector output_shapes_; + const int graph_def_version_; }; class MakeIteratorOp : public OpKernel { @@ -460,7 +479,8 @@ class OneShotIteratorOp : public AsyncOpKernel { ctx->env(), ThreadOptions(), strings::StrCat("one_shot_iterator_initialization_thread_", SanitizeThreadSuffix(name())), - 1 /* num_threads */, false /* low_latency_hint */)) + 1 /* num_threads */, false /* low_latency_hint */)), + graph_def_version_(ctx->graph_def_version()) { string shared_name; @@ -544,7 +564,8 @@ class OneShotIteratorOp : public AsyncOpKernel { ctx->resource_manager()->LookupOrCreate( cinfo->container(), cinfo->name(), iterator, [this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = new IteratorResource(output_dtypes_, output_shapes_); + *ret = new IteratorResource(output_dtypes_, output_shapes_, + graph_def_version_); return Status::OK(); })); @@ -634,6 +655,7 @@ class OneShotIteratorOp : public AsyncOpKernel { Status initialization_status_ GUARDED_BY(mu_); std::vector> done_callbacks_ GUARDED_BY(mu_); + const int graph_def_version_; }; class IteratorGetNextOp : public AsyncOpKernel { @@ -787,7 +809,7 @@ class SerializeIteratorOp : public OpKernel { Tensor* variant_t; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t)); IteratorStateVariant v; - OP_REQUIRES_OK(ctx, v.InitializeFromIterator(iterator_resource)); + OP_REQUIRES_OK(ctx, v.InitializeFromIterator(ctx, iterator_resource)); variant_t->scalar()() = v; } }; diff --git a/tensorflow/core/kernels/map_dataset_op.cc b/tensorflow/core/kernels/map_dataset_op.cc index ac458701fe2..4ba09bc335e 100644 --- a/tensorflow/core/kernels/map_dataset_op.cc +++ b/tensorflow/core/kernels/map_dataset_op.cc @@ -53,18 +53,21 @@ class MapDatasetOp : public UnaryDatasetOpKernel { std::move(other_arguments), &captured_func)); - *output = new Dataset(input, std::move(captured_func), output_types_, - output_shapes_); + *output = new Dataset(ctx, input, func_, std::move(captured_func), + output_types_, output_shapes_); } private: - class Dataset : public DatasetBase { + class Dataset : public GraphDatasetBase { public: - Dataset(const DatasetBase* input, + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, std::unique_ptr captured_func, const DataTypeVector& output_types, const std::vector& output_shapes) - : input_(input), + : GraphDatasetBase(ctx), + input_(input), + func_(func), captured_func_(std::move(captured_func)), output_types_(output_types), output_shapes_(output_shapes) { @@ -88,6 +91,37 @@ class MapDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "MapDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + + DataTypeVector other_arguments_types( + captured_func_->captured_inputs().size()); + std::vector other_arguments( + captured_func_->captured_inputs().size()); + for (const Tensor& t : captured_func_->captured_inputs()) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + other_arguments.emplace_back(node); + other_arguments_types.emplace_back(t.dtype()); + } + AttrValue f; + b->BuildAttrValue(func_, &f); + AttrValue other_arguments_types_attr; + b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, {std::make_pair(0, input_graph_node)}, // Single tensor inputs. + {std::make_pair(1, other_arguments)}, // Tensor list inputs. + {std::make_pair("f", f), + std::make_pair("Targuments", other_arguments_types_attr)}, // Attrs + output)); + return Status::OK(); + } + private: class Iterator : public DatasetIterator { public: @@ -133,11 +167,24 @@ class MapDatasetOp : public UnaryDatasetOpKernel { } } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + return Status::OK(); + } + private: const std::unique_ptr input_impl_; }; const DatasetBase* const input_; + const NameAttrList func_; const std::unique_ptr captured_func_; const DataTypeVector output_types_; const std::vector output_shapes_; diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc index 7b5a464b722..bdc3b5778f0 100644 --- a/tensorflow/core/kernels/map_stage_op.cc +++ b/tensorflow/core/kernels/map_stage_op.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" namespace tensorflow { namespace { @@ -36,16 +37,14 @@ namespace { // Partial Ordering Comparator for Tensor keys containing scalar int64's struct KeyTensorLess { bool operator()(const Tensor& lhs, const Tensor& rhs) const { - return std::less{}(lhs.scalar()(), - rhs.scalar()()); + return std::less{}(lhs.scalar()(), rhs.scalar()()); } }; // Key Equality operator for Tensor keys containing scalar int64's struct KeyTensorEqual { bool operator()(const Tensor& lhs, const Tensor& rhs) const { - return std::equal_to{}(lhs.scalar()(), - rhs.scalar()()); + return std::equal_to{}(lhs.scalar()(), rhs.scalar()()); } }; @@ -93,24 +92,23 @@ class StagingMap : public ResourceBase { private: // Private variables - DataTypeVector dtypes_; - std::size_t capacity_; - std::size_t memory_limit_; - std::size_t current_bytes_; - std::mutex mu_; - std::condition_variable not_empty_; - std::condition_variable full_; - IncompleteType incomplete_; - MapType map_; + DataTypeVector dtypes_ GUARDED_BY(mu_); + std::size_t capacity_ GUARDED_BY(mu_); + std::size_t memory_limit_ GUARDED_BY(mu_); + std::size_t current_bytes_ GUARDED_BY(mu_); + tensorflow::mutex mu_; + tensorflow::condition_variable not_empty_; + tensorflow::condition_variable full_; + IncompleteType incomplete_ GUARDED_BY(mu_); + MapType map_ GUARDED_BY(mu_); private: // private methods // If map is configured for bounded capacity, notify // waiting inserters that space is now available - void notify_inserters_if_bounded(std::unique_lock* lock) { + void notify_inserters_if_bounded() EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (has_capacity() || has_memory_limit()) { - lock->unlock(); // Notify all inserters. The removal of an element // may make memory available for many inserters // to insert new elements @@ -120,23 +118,29 @@ class StagingMap : public ResourceBase { // Notify all removers waiting to extract values // that data is now available - void notify_removers(std::unique_lock* lock) { - lock->unlock(); + void notify_removers() { // Notify all removers. This is because they are // waiting for specific keys to appear in the map // so we don't know which one to wake up. not_empty_.notify_all(); } - bool has_capacity() const { return capacity_ > 0; } - - bool has_memory_limit() const { return memory_limit_ > 0; } - - bool would_exceed_memory_limit(std::size_t bytes) const { - return bytes + current_bytes_ > memory_limit_; + bool has_capacity() const EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return capacity_ > 0; } - bool is_capacity_full() const { return map_.size() >= capacity_; } + bool has_memory_limit() const EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return memory_limit_ > 0; + } + + bool would_exceed_memory_limit(std::size_t bytes) const + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return has_memory_limit() && bytes + current_bytes_ > memory_limit_; + } + + bool is_capacity_full() const EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return has_capacity() && map_.size() >= capacity_; + } // Get number of bytes in the tuple std::size_t get_tuple_bytes(const Tuple& tuple) { @@ -157,7 +161,8 @@ class StagingMap : public ResourceBase { } // Check that the index is within bounds - Status check_index(const Tensor& key, std::size_t index) { + Status check_index(const Tensor& key, std::size_t index) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (index >= dtypes_.size()) { return Status(errors::InvalidArgument( "Index '", index, "' for key '", key.scalar()(), @@ -169,7 +174,7 @@ class StagingMap : public ResourceBase { Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key, const Tensor& indices, Tuple* output, - bool copy = false) { + bool copy = false) EXCLUSIVE_LOCKS_REQUIRED(mu_) { auto findices = indices.flat(); // Return values at specified indices @@ -201,11 +206,12 @@ class StagingMap : public ResourceBase { // Check that the optional value at the specified index // is uninitialized Status check_index_uninitialized(const Tensor& key, std::size_t index, - const OptionalTuple& tuple) { + const OptionalTuple& tuple) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (tuple[index].has_value()) { - return Status(errors::InvalidArgument("The tensor for index '", - index, "' for key '", key.scalar()(), - "' was already initialized '", dtypes_.size(), "'.")); + return Status(errors::InvalidArgument( + "The tensor for index '", index, "' for key '", key.scalar()(), + "' was already initialized '", dtypes_.size(), "'.")); } return Status::OK(); @@ -228,7 +234,7 @@ class StagingMap : public ResourceBase { } // Check bytes are within memory limits memory limits - Status check_memory_limit(std::size_t bytes) { + Status check_memory_limit(std::size_t bytes) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (has_memory_limit() && bytes > memory_limit_) { return Status(errors::ResourceExhausted( "Attempted to insert tensors with combined size of '", bytes, @@ -241,8 +247,8 @@ class StagingMap : public ResourceBase { // Insert incomplete data into the Barrier Status put_incomplete(const KeyType& key, const Tensor& indices, - OptionalTuple* tuple, - std::unique_lock* lock) { + OptionalTuple* tuple, tensorflow::mutex_lock* lock) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { auto findices = indices.flat(); // Search for the key in our incomplete set @@ -252,11 +258,9 @@ class StagingMap : public ResourceBase { std::size_t tuple_bytes = get_tuple_bytes(*tuple); TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes)); - if (has_memory_limit()) { - full_.wait(*lock, [tuple_bytes, this]() { - // Stop waiting if we don't exceed the memory limit - return !would_exceed_memory_limit(tuple_bytes); - }); + // Wait until we don't exceed the memory limit + while (would_exceed_memory_limit(tuple_bytes)) { + full_.wait(*lock); } // This key isn't present in the incomplete set @@ -282,8 +286,7 @@ class StagingMap : public ResourceBase { // Found an entry in the incomplete index // Update with given data and insert complete entries // into the main map - else - { + else { // Reference existing incomplete tuple OptionalTuple& present = it->second; @@ -312,7 +315,7 @@ class StagingMap : public ResourceBase { // Remove from incomplete incomplete_.erase(it); - TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple, lock)); + TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple)); } } @@ -320,12 +323,12 @@ class StagingMap : public ResourceBase { } // Does the insertion into the actual staging area - Status put_complete(const KeyType& key, OptionalTuple* tuple, - std::unique_lock* lock) { + Status put_complete(const KeyType& key, OptionalTuple* tuple) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { // Insert key and tuples into the map map_.insert({key, std::move(*tuple)}); - notify_removers(lock); + notify_removers(); return Status::OK(); } @@ -340,7 +343,7 @@ class StagingMap : public ResourceBase { current_bytes_(0) {} Status put(KeyType* key, const Tensor* indices, OptionalTuple* tuple) { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); // Sanity check the indices TF_RETURN_IF_ERROR(check_index_ordering(*indices)); @@ -354,22 +357,13 @@ class StagingMap : public ResourceBase { // Check that tuple_bytes fits within the memory limit TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes)); - // If map capacity is bounded wait until map is not full - if (has_capacity() || has_memory_limit()) { - full_.wait(lock, [tuple_bytes, this]() { - // If there's a memory limit, check if there's space for insertion - bool memory_limit_valid = - has_memory_limit() ? !would_exceed_memory_limit(tuple_bytes) : true; - // If we're configured for capacity check if there's space for insertion - bool capacity_valid = has_capacity() ? !is_capacity_full() : true; - - // Stop waiting upon success for both conditions - return memory_limit_valid && capacity_valid; - }); + // Wait until there's space for insertion. + while (would_exceed_memory_limit(tuple_bytes) || is_capacity_full()) { + full_.wait(lock); } // Do the put operation - TF_RETURN_IF_ERROR(put_complete(*key, tuple, &lock)); + TF_RETURN_IF_ERROR(put_complete(*key, tuple)); // Update the current size current_bytes_ += tuple_bytes; @@ -378,7 +372,7 @@ class StagingMap : public ResourceBase { } Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); // Sanity check the indices TF_RETURN_IF_ERROR(check_index_ordering(*indices)); @@ -386,8 +380,9 @@ class StagingMap : public ResourceBase { typename MapType::iterator it; // Wait until the element with the requested key is present - not_empty_.wait( - lock, [&, this]() { return (it = map_.find(*key)) != map_.end(); }); + while ((it = map_.find(*key)) == map_.end()) { + not_empty_.wait(lock); + } TF_RETURN_IF_ERROR( copy_or_move_tensors(&it->second, *key, *indices, tuple, true)); @@ -399,7 +394,7 @@ class StagingMap : public ResourceBase { } Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); // Sanity check the indices TF_RETURN_IF_ERROR(check_index_ordering(*indices)); @@ -407,8 +402,9 @@ class StagingMap : public ResourceBase { typename MapType::iterator it; // Wait until the element with the requested key is present - not_empty_.wait( - lock, [&, this]() { return (it = map_.find(*key)) != map_.end(); }); + while ((it = map_.find(*key)) == map_.end()) { + not_empty_.wait(lock); + } TF_RETURN_IF_ERROR( copy_or_move_tensors(&it->second, *key, *indices, tuple)); @@ -422,19 +418,21 @@ class StagingMap : public ResourceBase { // Update bytes in the Staging Area current_bytes_ -= get_tuple_bytes(*tuple); - notify_inserters_if_bounded(&lock); + notify_inserters_if_bounded(); return Status::OK(); } Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); // Sanity check the indices TF_RETURN_IF_ERROR(check_index_ordering(*indices)); // Wait until map is not empty - not_empty_.wait(lock, [this]() { return !this->map_.empty(); }); + while (this->map_.empty()) { + not_empty_.wait(lock); + } // Move from the first element and erase it @@ -454,29 +452,29 @@ class StagingMap : public ResourceBase { // Update bytes in the Staging Area current_bytes_ -= get_tuple_bytes(*tuple); - notify_inserters_if_bounded(&lock); + notify_inserters_if_bounded(); return Status::OK(); } Status clear() { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); map_.clear(); incomplete_.clear(); current_bytes_ = 0; - notify_inserters_if_bounded(&lock); + notify_inserters_if_bounded(); return Status::OK(); } std::size_t incomplete_size() { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); return incomplete_.size(); } std::size_t size() { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); return map_.size(); } @@ -539,10 +537,9 @@ class MapStageOp : public OpKernel { } }; -REGISTER_KERNEL_BUILDER(Name("MapStage").Device(DEVICE_CPU), - MapStageOp); +REGISTER_KERNEL_BUILDER(Name("MapStage").Device(DEVICE_CPU), MapStageOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapStage").Device(DEVICE_CPU), - MapStageOp); + MapStageOp); #if GOOGLE_CUDA REGISTER_KERNEL_BUILDER( @@ -553,7 +550,7 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapStage") .HostMemory("indices") .Device(DEVICE_GPU), MapStageOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("MapStage") @@ -601,30 +598,34 @@ class MapUnstageOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("MapUnstage").Device(DEVICE_CPU), - MapUnstageOp); + MapUnstageOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").Device(DEVICE_CPU), - MapUnstageOp); + MapUnstageOp); #if GOOGLE_CUDA REGISTER_KERNEL_BUILDER(Name("MapUnstage") - .HostMemory("key") - .HostMemory("indices") - .Device(DEVICE_GPU), MapUnstageOp); + .HostMemory("key") + .HostMemory("indices") + .Device(DEVICE_GPU), + MapUnstageOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage") - .HostMemory("key") - .HostMemory("indices") - .Device(DEVICE_GPU), MapUnstageOp); + .HostMemory("key") + .HostMemory("indices") + .Device(DEVICE_GPU), + MapUnstageOp); #endif #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("MapUnstage") - .HostMemory("key") - .HostMemory("indices") - .Device(DEVICE_SYCL), MapUnstageOp); + .HostMemory("key") + .HostMemory("indices") + .Device(DEVICE_SYCL), + MapUnstageOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage") - .HostMemory("key") - .HostMemory("indices") - .Device(DEVICE_SYCL), MapUnstageOp); -#endif // TENSORFLOW_USE_SYCL + .HostMemory("key") + .HostMemory("indices") + .Device(DEVICE_SYCL), + MapUnstageOp); +#endif // TENSORFLOW_USE_SYCL template class MapPeekOp : public OpKernel { @@ -682,7 +683,7 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek") .HostMemory("indices") .Device(DEVICE_SYCL), MapPeekOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template class MapUnstageNoKeyOp : public OpKernel { @@ -715,7 +716,7 @@ class MapUnstageNoKeyOp : public OpKernel { " vs. ", indices_tensor->NumElements())); for (std::size_t i = 0; i < tuple.size(); ++i) { - ctx->set_output(i+1, tuple[i]); + ctx->set_output(i + 1, tuple[i]); } } }; @@ -749,7 +750,7 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey") .HostMemory("indices") .Device(DEVICE_SYCL), MapUnstageNoKeyOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template class MapSizeOp : public OpKernel { @@ -770,23 +771,24 @@ class MapSizeOp : public OpKernel { } }; -REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_CPU), - MapSizeOp); +REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_CPU), MapSizeOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_CPU), MapSizeOp); #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_GPU) - .HostMemory("size"), MapSizeOp); -REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_GPU) - .HostMemory("size"), MapSizeOp); +REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_GPU).HostMemory("size"), + MapSizeOp); +REGISTER_KERNEL_BUILDER( + Name("OrderedMapSize").Device(DEVICE_GPU).HostMemory("size"), + MapSizeOp); #endif #ifdef TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_SYCL) - .HostMemory("size"), MapSizeOp); -REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_SYCL) - .HostMemory("size"), MapSizeOp); -#endif // TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_SYCL).HostMemory("size"), + MapSizeOp); +REGISTER_KERNEL_BUILDER( + Name("OrderedMapSize").Device(DEVICE_SYCL).HostMemory("size"), + MapSizeOp); +#endif // TENSORFLOW_USE_SYCL template class MapIncompleteSizeOp : public OpKernel { @@ -813,17 +815,21 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_CPU), MapIncompleteSizeOp); #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MapIncompleteSize").Device(DEVICE_GPU) - .HostMemory("size"), MapIncompleteSizeOp); -REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_GPU) - .HostMemory("size"), MapIncompleteSizeOp); +REGISTER_KERNEL_BUILDER( + Name("MapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"), + MapIncompleteSizeOp); +REGISTER_KERNEL_BUILDER( + Name("OrderedMapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"), + MapIncompleteSizeOp); #endif #ifdef TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("MapIncompleteSize").Device(DEVICE_SYCL) - .HostMemory("size"), MapIncompleteSizeOp); -REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_SYCL) - .HostMemory("size"), MapIncompleteSizeOp); -#endif // TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER( + Name("MapIncompleteSize").Device(DEVICE_SYCL).HostMemory("size"), + MapIncompleteSizeOp); +REGISTER_KERNEL_BUILDER( + Name("OrderedMapIncompleteSize").Device(DEVICE_SYCL).HostMemory("size"), + MapIncompleteSizeOp); +#endif // TENSORFLOW_USE_SYCL template class MapClearOp : public OpKernel { @@ -839,14 +845,12 @@ class MapClearOp : public OpKernel { } }; -REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_CPU), - MapClearOp); +REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_CPU), MapClearOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_CPU), MapClearOp); #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_GPU), - MapClearOp); +REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_GPU), MapClearOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_GPU), MapClearOp); #endif @@ -855,7 +859,7 @@ REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_SYCL), MapClearOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_SYCL), MapClearOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc index 0167b9ea64b..3d977a0fa38 100644 --- a/tensorflow/core/kernels/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/repeat_dataset_op.cc @@ -73,10 +73,10 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "RepeatDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); Node* count = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc index 4f2afa52579..7ac34d1c623 100644 --- a/tensorflow/core/kernels/reverse_op.cc +++ b/tensorflow/core/kernels/reverse_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/status.h" @@ -35,7 +36,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL namespace { @@ -43,7 +44,7 @@ namespace { // NUM_CHANNELS can be <= 0 to compute it dynamically from // Otherwise, it must equal input.dim_size(2) and is used as a compile-time // constant. -template +template void ReverseRows(OpKernelContext* context, const Tensor& input, Tensor* result) { auto work = [&input, result](int64 start, int64 end) { @@ -53,8 +54,8 @@ void ReverseRows(OpKernelContext* context, const Tensor& input, const int64 row_size = inner_size * middle_size; DCHECK_EQ(input.dim_size(2), inner_size); - const int32* in_ptr = input.bit_casted_tensor().data(); - int32* out_ptr = result->bit_casted_tensor().data(); + const T* in_ptr = input.bit_casted_tensor().data(); + T* out_ptr = result->bit_casted_tensor().data(); in_ptr += start * row_size; out_ptr += start * row_size; @@ -64,7 +65,7 @@ void ReverseRows(OpKernelContext* context, const Tensor& input, int remaining = middle_size; while (remaining > 0) { out_ptr -= inner_size; - memcpy(out_ptr, in_ptr, inner_size * sizeof(float)); + memcpy(out_ptr, in_ptr, inner_size * sizeof(T)); in_ptr += inner_size; --remaining; } @@ -81,6 +82,48 @@ void ReverseRows(OpKernelContext* context, const Tensor& input, std::move(work)); } +template +struct data_type_can_memcpy { + static constexpr bool value = + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value; +}; + +template +typename std::enable_if::value>::type +DoHandleReverseCase(OpKernelContext* context, const Tensor& input, + Tensor* result) { + if (sizeof(T) == 1) { + static_assert(sizeof(uint8) == 1, "uint8 must be 1 byte."); + ReverseRows(context, input, result); + } else if (sizeof(T) == 2) { + static_assert(sizeof(uint16) == 2, "uint16 must be 2 bytes"); + ReverseRows(context, input, result); + } else if (sizeof(T) == 4) { + static_assert(sizeof(uint32) == 4, "uint32 must be 4 bytes"); + ReverseRows(context, input, result); + } else if (sizeof(T) == 8) { + static_assert(sizeof(uint64) == 8, "uint64 must be 8 bytes"); + ReverseRows(context, input, result); + } else if (sizeof(T) == 16) { + static_assert(sizeof(complex128) == 16, "complex128 must be 16 bytes"); + ReverseRows(context, input, result); + } else { + context->CtxFailure( + errors::InvalidArgument("%s has unexpected size of %d bytes", + DataTypeString(input.dtype()), sizeof(T))); + } +} + +template +typename std::enable_if::value>::type +DoHandleReverseCase(OpKernelContext* context, const Tensor& input, + Tensor* result) {} + } // namespace template @@ -91,15 +134,14 @@ void HandleReverseCase(OpKernelContext* context, // Use optimized reverse if possible. if (NDIMS == 3 && std::is_same::value && - std::is_same::value && (!dims(0) && dims(1) && !dims(2))) { + data_type_can_memcpy::value && (!dims(0) && dims(1) && !dims(2))) { if (input.dim_size(2) == 3) { - ReverseRows<3>(context, input, result); + DoHandleReverseCase(context, input, result); } else { - ReverseRows<-1>(context, input, result); + DoHandleReverseCase(context, input, result); } return; } - typename Eigen::array axes_di; for (int i = 0; i < NDIMS; i++) { axes_di[i] = dims(i); @@ -168,11 +210,11 @@ void HandleReverseV2Case(OpKernelContext* context, // Use optimized reverse if possible. if (NDIMS == 3 && std::is_same::value && - std::is_same::value && (!axes[0] && axes[1] && !axes[2])) { + data_type_can_memcpy::value && (!axes[0] && axes[1] && !axes[2])) { if (input.dim_size(2) == 3) { - ReverseRows<3>(context, input, result); + DoHandleReverseCase(context, input, result); } else { - ReverseRows<-1>(context, input, result); + DoHandleReverseCase(context, input, result); } return; } diff --git a/tensorflow/core/kernels/reverse_op_test.cc b/tensorflow/core/kernels/reverse_op_test.cc index 9829e40fe85..e8285fb0e24 100644 --- a/tensorflow/core/kernels/reverse_op_test.cc +++ b/tensorflow/core/kernels/reverse_op_test.cc @@ -46,69 +46,132 @@ class ReverseOpTest : public OpsTestBase { .Finalize(node_def())); TF_ASSERT_OK(InitOp()); } + + template + void Reverse_0() { + MakeOp(DataTypeToEnum::value); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {true}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DataTypeToEnum::value, TensorShape({})); + expected.scalar() = expected.scalar().constant(3); + test::ExpectTensorEqual(expected, *output); + } + + template + void Reverse_234() { + MakeOp(DataTypeToEnum::value); + // Feed and run + // [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + AddInputFromArray(TensorShape({2, 3, 4}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + AddInputFromArray(TensorShape({3}), {true, false, true}); + + TF_ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor* params_tensor = GetOutput(0); + Tensor expected(allocator(), DataTypeToEnum::value, + TensorShape({2, 3, 4})); + // Should become + // [[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] + // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]] + test::FillValues(&expected, + {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8}); + test::ExpectTensorEqual(expected, *params_tensor); + } + + template + void Reverse_1234() { + MakeOp(DataTypeToEnum::value); + // Feed and run + // [[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]] + AddInputFromArray(TensorShape({1, 2, 3, 4}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + AddInputFromArray(TensorShape({4}), {true, true, false, true}); + + TF_ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor* params_tensor = GetOutput(0); + Tensor expected(allocator(), DataTypeToEnum::value, + TensorShape({1, 2, 3, 4})); + // Should become + // [[[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] + // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]]] + test::FillValues(&expected, + {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8}); + test::ExpectTensorEqual(expected, *params_tensor); + } }; -TEST_F(ReverseOpTest, Reverse_0) { - MakeOp(DT_FLOAT); - AddInputFromArray(TensorShape({}), {3}); - AddInputFromArray(TensorShape({}), {true}); - TF_ASSERT_OK(RunOpKernel()); +TEST_F(ReverseOpTest, Reverse_0_uint8) { Reverse_0(); } - Tensor* output = GetOutput(0); - Tensor expected(allocator(), DT_FLOAT, TensorShape({})); - expected.scalar() = expected.scalar().constant(3.f); - test::ExpectTensorEqual(expected, *output); -} +TEST_F(ReverseOpTest, Reverse_0_int8) { Reverse_0(); } -TEST_F(ReverseOpTest, Reverse_234) { - MakeOp(DT_FLOAT); +TEST_F(ReverseOpTest, Reverse_0_uint16) { Reverse_0(); } - // Feed and run - // [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] - // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] - AddInputFromArray(TensorShape({2, 3, 4}), - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23}); - AddInputFromArray(TensorShape({3}), {true, false, true}); +TEST_F(ReverseOpTest, Reverse_0_int16) { Reverse_0(); } - TF_ASSERT_OK(RunOpKernel()); +TEST_F(ReverseOpTest, Reverse_0_float) { Reverse_0(); } - // Check the new state of the input - Tensor* params_tensor = GetOutput(0); - Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3, 4})); - // Should become - // [[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] - // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]] - test::FillValues( - &expected, {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 3, 2, 1, 0, 7, - 6, 5, 4, 11, 10, 9, 8}); - test::ExpectTensorEqual(expected, *params_tensor); -} +TEST_F(ReverseOpTest, Reverse_0_int32) { Reverse_0(); } -TEST_F(ReverseOpTest, Reverse_1234) { - MakeOp(DT_FLOAT); +TEST_F(ReverseOpTest, Reverse_0_int64) { Reverse_0(); } - // Feed and run - // [[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] - // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]] - AddInputFromArray(TensorShape({1, 2, 3, 4}), - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23}); - AddInputFromArray(TensorShape({4}), {true, true, false, true}); +TEST_F(ReverseOpTest, Reverse_0_double) { Reverse_0(); } - TF_ASSERT_OK(RunOpKernel()); +TEST_F(ReverseOpTest, Reverse_0_complex64) { Reverse_0(); } - // Check the new state of the input - Tensor* params_tensor = GetOutput(0); - Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 3, 4})); - // Should become - // [[[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] - // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]]] - test::FillValues( - &expected, {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 3, 2, 1, 0, 7, - 6, 5, 4, 11, 10, 9, 8}); - test::ExpectTensorEqual(expected, *params_tensor); -} +TEST_F(ReverseOpTest, Reverse_0_complex128) { Reverse_0(); } + +TEST_F(ReverseOpTest, Reverse_234_uint8) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_int8) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_uint16) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_int16) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_float) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_int32) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_int64) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_double) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_complex64) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_complex128) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_1234_uint8) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_int8) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_uint16) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_int16) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_float) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_int32) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_int64) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_double) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_complex64) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_complex128) { Reverse_1234(); } static SessionOptions GetOptions(int intra_threads) { SessionOptions opts; @@ -119,10 +182,11 @@ static SessionOptions GetOptions(int intra_threads) { // Creates a Graph which "reduce"s a 3D float tensor of "num" elements // into a scalar. +template static Graph* Reverse(const TensorShape& shape, int reverse_axis) { Graph* g = new Graph(OpRegistry::Global()); - Tensor data(DT_FLOAT, shape); - data.flat().setRandom(); + Tensor data(DataTypeToEnum::value, shape); + data.flat().setRandom(); Tensor axes(DT_INT32, TensorShape({1})); axes.flat()(0) = reverse_axis; test::graph::Reverse(g, test::graph::Constant(g, data), @@ -130,81 +194,149 @@ static Graph* Reverse(const TensorShape& shape, int reverse_axis) { return g; } +template static void RunReverseRowsBenchmark(int iters, int outer_dim, int middle_dim, int intra_threads, int channels) { SessionOptions opts = GetOptions(intra_threads); TensorShape shape{outer_dim, middle_dim, channels}; const int64 num_items = static_cast(iters) * shape.num_elements(); testing::ItemsProcessed(num_items); - testing::BytesProcessed(num_items * sizeof(float)); + testing::BytesProcessed(num_items * sizeof(T)); testing::UseRealTime(); - test::Benchmark("cpu", Reverse(shape, 1), &opts).Run(iters); + test::Benchmark("cpu", Reverse(shape, 1), &opts).Run(iters); } -static void BM_ReverseRowsOf1Channel_1T(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 1 /* intra_threads */, - 1 /* channels */); +static void BM_ReverseRowsOf1Channel_1T_float(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 1 /* intra_threads */, 1 /* channels */); } -BENCHMARK(BM_ReverseRowsOf1Channel_1T) +BENCHMARK(BM_ReverseRowsOf1Channel_1T_float) ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf1Channel_4T(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 4 /* intra_threads */, - 1 /* channels */); +static void BM_ReverseRowsOf1Channel_1T_uint8(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 1 /* intra_threads */, 1 /* channels */); } -BENCHMARK(BM_ReverseRowsOf1Channel_4T) +BENCHMARK(BM_ReverseRowsOf1Channel_1T_uint8) ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf3Channels_1T(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 1 /* intra_threads */, - 3 /* channels */); +static void BM_ReverseRowsOf1Channel_4T_float(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 4 /* intra_threads */, 1 /* channels */); } -BENCHMARK(BM_ReverseRowsOf3Channels_1T) - ->ArgPair(288, 288) - ->ArgPair(224, 224) - ->ArgPair(1024, 1024) - ->ArgPair(10 * 1024, 1024); - -static void BM_ReverseRowsOf3Channels_4T(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 4 /* intra_threads */, - 3 /* channels */); -} - -BENCHMARK(BM_ReverseRowsOf3Channels_4T) - ->ArgPair(288, 288) - ->ArgPair(224, 224) - ->ArgPair(1024, 1024) - ->ArgPair(10 * 1024, 1024); - -static void BM_ReverseRowsOf4Channels_1T(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 1 /* intra_threads */, - 4 /* channels */); -} - -BENCHMARK(BM_ReverseRowsOf4Channels_1T) +BENCHMARK(BM_ReverseRowsOf1Channel_4T_float) ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf4Channels_4T(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 4 /* intra_threads */, - 4 /* channels */); +static void BM_ReverseRowsOf1Channel_4T_uint8(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 4 /* intra_threads */, 1 /* channels */); } -BENCHMARK(BM_ReverseRowsOf4Channels_4T) +BENCHMARK(BM_ReverseRowsOf1Channel_4T_uint8) + ->ArgPair(288, 288) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf3Channels_1T_float(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 1 /* intra_threads */, 3 /* channels */); +} + +BENCHMARK(BM_ReverseRowsOf3Channels_1T_float) + ->ArgPair(288, 288) + ->ArgPair(30, 30) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf3Channels_1T_uint8(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 1 /* intra_threads */, 3 /* channels */); +} + +BENCHMARK(BM_ReverseRowsOf3Channels_1T_uint8) + ->ArgPair(288, 288) + ->ArgPair(30, 30) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf3Channels_4T_float(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 4 /* intra_threads */, 3 /* channels */); +} + +BENCHMARK(BM_ReverseRowsOf3Channels_4T_float) + ->ArgPair(288, 288) + ->ArgPair(30, 30) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf3Channels_4T_uint8(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 4 /* intra_threads */, 3 /* channels */); +} +BENCHMARK(BM_ReverseRowsOf3Channels_4T_uint8) + ->ArgPair(288, 288) + ->ArgPair(30, 30) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf4Channels_1T_float(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 1 /* intra_threads */, 4 /* channels */); +} + +BENCHMARK(BM_ReverseRowsOf4Channels_1T_float) + ->ArgPair(288, 288) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf4Channels_1T_uint8(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 1 /* intra_threads */, 4 /* channels */); +} + +BENCHMARK(BM_ReverseRowsOf4Channels_1T_uint8) + ->ArgPair(288, 288) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf4Channels_4T_float(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 4 /* intra_threads */, 4 /* channels */); +} + +BENCHMARK(BM_ReverseRowsOf4Channels_4T_float) + ->ArgPair(288, 288) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf4Channels_4T_uint8(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 4 /* intra_threads */, 4 /* channels */); +} + +BENCHMARK(BM_ReverseRowsOf4Channels_4T_uint8) ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc index 9c242052f7c..542382872cc 100644 --- a/tensorflow/core/kernels/sendrecv_ops.cc +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -91,9 +91,9 @@ void SendOp::Compute(OpKernelContext* ctx) { if (frame_iter == FrameAndIter(0, 0)) { // Use the cached rendezvous key. VLOG(2) << "Send " << parsed_key_.buf_; - OP_REQUIRES_OK(ctx, - ctx->rendezvous()->Send(parsed_key_, args, ctx->input(0), + ctx->SetStatus(ctx->rendezvous()->Send(parsed_key_, args, ctx->input(0), ctx->is_input_dead())); + return; } else { Rendezvous::ParsedKey in_loop_parsed; GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_); @@ -101,9 +101,9 @@ void SendOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed)); - OP_REQUIRES_OK(ctx, - ctx->rendezvous()->Send(in_loop_parsed, args, ctx->input(0), + ctx->SetStatus(ctx->rendezvous()->Send(in_loop_parsed, args, ctx->input(0), ctx->is_input_dead())); + return; } } diff --git a/tensorflow/core/kernels/shuffle_dataset_op.cc b/tensorflow/core/kernels/shuffle_dataset_op.cc index dd0ab57e9dc..72facb3a0d0 100644 --- a/tensorflow/core/kernels/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/shuffle_dataset_op.cc @@ -308,10 +308,10 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); Node* buffer_size = nullptr; Node* seed = nullptr; Node* seed2 = nullptr; diff --git a/tensorflow/core/kernels/skip_dataset_op.cc b/tensorflow/core/kernels/skip_dataset_op.cc index 7ee945dd4c4..1fe49271e29 100644 --- a/tensorflow/core/kernels/skip_dataset_op.cc +++ b/tensorflow/core/kernels/skip_dataset_op.cc @@ -72,10 +72,10 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "SkipDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); Node* count = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/summary_interface.cc b/tensorflow/core/kernels/summary_interface.cc index cd366f8c137..ad28d77ffde 100644 --- a/tensorflow/core/kernels/summary_interface.cc +++ b/tensorflow/core/kernels/summary_interface.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/summary.pb.h" @@ -393,6 +394,15 @@ class SummaryWriterImpl : public SummaryWriterInterface { return WriteEvent(std::move(e)); } + Status WriteGraph(int64 global_step, + std::unique_ptr graph) override { + std::unique_ptr e{new Event}; + e->set_step(global_step); + e->set_wall_time(GetWallTime()); + graph->SerializeToString(e->mutable_graph_def()); + return WriteEvent(std::move(e)); + } + Status WriteEvent(std::unique_ptr event) override { mutex_lock ml(mu_); queue_.emplace_back(std::move(event)); diff --git a/tensorflow/core/kernels/summary_interface.h b/tensorflow/core/kernels/summary_interface.h index ccf3459e56b..da1c28709fb 100644 --- a/tensorflow/core/kernels/summary_interface.h +++ b/tensorflow/core/kernels/summary_interface.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/util/event.pb.h" @@ -46,6 +47,9 @@ class SummaryWriterInterface : public ResourceBase { virtual Status WriteAudio(int64 global_step, Tensor t, const string& tag, int max_outputs_, float sample_rate) = 0; + virtual Status WriteGraph(int64 global_step, + std::unique_ptr graph) = 0; + virtual Status WriteEvent(std::unique_ptr e) = 0; }; diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc index 1fe2fc5b666..3706f51cf40 100644 --- a/tensorflow/core/kernels/summary_kernels.cc +++ b/tensorflow/core/kernels/summary_kernels.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/kernels/summary_interface.h" @@ -268,4 +269,28 @@ class WriteAudioSummaryOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU), WriteAudioSummaryOp); +class WriteGraphSummaryOp : public OpKernel { + public: + explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + SummaryWriterInterface* s; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); + core::ScopedUnref unref(s); + const Tensor* t; + OP_REQUIRES_OK(ctx, ctx->input("global_step", &t)); + const int64 global_step = t->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); + std::unique_ptr graph{new GraphDef}; + if (!ParseProtoUnlimited(graph.get(), t->scalar()())) { + ctx->CtxFailureWithWarning( + errors::DataLoss("Bad tf.GraphDef binary proto tensor string")); + return; + } + OP_REQUIRES_OK(ctx, s->WriteGraph(global_step, std::move(graph))); + } +}; +REGISTER_KERNEL_BUILDER(Name("WriteGraphSummary").Device(DEVICE_CPU), + WriteGraphSummaryOp); + } // namespace tensorflow diff --git a/tensorflow/core/kernels/take_dataset_op.cc b/tensorflow/core/kernels/take_dataset_op.cc index fb294a96b15..7a6d20d6c7c 100644 --- a/tensorflow/core/kernels/take_dataset_op.cc +++ b/tensorflow/core/kernels/take_dataset_op.cc @@ -73,10 +73,10 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "TakeDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); Node* count = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h index 2a41d4c419a..90b71e370c4 100644 --- a/tensorflow/core/kernels/tensor_array.h +++ b/tensorflow/core/kernels/tensor_array.h @@ -138,8 +138,9 @@ class TensorArray : public ResourceBase { // users to construct this many Tensors for storage in a TensorArray. TensorArray(const string& key, const DataType& dtype, const Tensor& handle, int32 N, const PartialTensorShape& element_shape, - bool dynamic_size, bool multiple_writes_aggregate, bool is_grad, - int32 marked_size, bool clear_after_read) + bool identical_element_shapes, bool dynamic_size, + bool multiple_writes_aggregate, bool is_grad, int32 marked_size, + bool clear_after_read) : key_(key), dtype_(dtype), handle_(handle), @@ -151,6 +152,7 @@ class TensorArray : public ResourceBase { is_grad_(is_grad), marked_size_(marked_size), element_shape_(element_shape), + identical_element_shapes_(identical_element_shapes), tensors_(N) {} // Write PersistentTensor 'value' to index 'index'. @@ -320,6 +322,8 @@ class TensorArray : public ResourceBase { return !gradients_disallowed_; } + bool HasIdenticalElementShapes() const { return identical_element_shapes_; } + // Copy the TensorShapes from another TensorArray into this one. // The sizes of the two TensorArrays must match and this one // may not have any entries filled in. This performs a "soft copy", @@ -379,7 +383,7 @@ class TensorArray : public ResourceBase { // Multiple writes to the same index will result in summation of the // values (used by backprop) - bool multiple_writes_aggregate_; + const bool multiple_writes_aggregate_; // If multiple Writes were attempted (e.g. via attribute // multiple_writes_aggregate), then gradients are disallowed. @@ -387,10 +391,10 @@ class TensorArray : public ResourceBase { // After a read at an index, clear away its PersistentTensor to // release memory. - bool clear_after_read_; + const bool clear_after_read_; // True iff this is a gradient tensor array. - bool is_grad_; + const bool is_grad_; // The size of the TensorArray after a (legacy) unpack or split is performed. // -1 if there has been no unpack or split performed on the TensorArray. @@ -400,6 +404,13 @@ class TensorArray : public ResourceBase { // known at all. PartialTensorShape element_shape_ GUARDED_BY(mu_); + // Whether all elements in the TensorArray have identical shapes. + // This allows certain behaviors, like dynamically checking for + // consistent shapes on write, and being able to fill in properly + // shaped zero tensors on stack -- even if the initial element_shape + // was not fully defined. + const bool identical_element_shapes_; + // TensorAndState is used to keep track of the PersistentTensors // stored in the TensorArray, along with their shapes, and a boolean // that determines whether they have already been read or not. @@ -463,6 +474,8 @@ Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, " which is incompatible with the TensorArray's inferred element " "shape: ", element_shape_.DebugString(), " (consider setting infer_shape=False)."); + } else if (identical_element_shapes_ && !element_shape_.IsFullyDefined()) { + element_shape_ = PartialTensorShape(value_t->shape().dim_sizes()); } if (t.read) { diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 2191e4e8c5f..cca6d0e35f2 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -162,6 +162,14 @@ class TensorArrayOp : public TensorArrayCreationOp { OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); OP_REQUIRES_OK(context, context->GetAttr("element_shape", &element_shape_)); OP_REQUIRES_OK(context, context->GetAttr("dynamic_size", &dynamic_size_)); + // The HasAttr check is for backwards compatibility with older op + // versions which do not have this attribute. + if (context->HasAttr("identical_element_shapes")) { + OP_REQUIRES_OK(context, context->GetAttr("identical_element_shapes", + &identical_element_shapes_)); + } else { + identical_element_shapes_ = false; + } OP_REQUIRES_OK(context, context->GetAttr("clear_after_read", &clear_after_read_)); OP_REQUIRES_OK(context, @@ -196,8 +204,9 @@ class TensorArrayOp : public TensorArrayCreationOp { TensorArray* tensor_array = new TensorArray( key, dtype_, *tensor_array_output_handle, size, element_shape_, - dynamic_size_, false /* multiple_writes_aggregate */, - false /* is_grad */, -1 /* marked_size */, clear_after_read_); + identical_element_shapes_, dynamic_size_, + false /* multiple_writes_aggregate */, false /* is_grad */, + -1 /* marked_size */, clear_after_read_); TF_RETURN_IF_ERROR( rm->Create(ctx->step_container()->name(), key, tensor_array)); @@ -210,6 +219,7 @@ class TensorArrayOp : public TensorArrayCreationOp { private: DataType dtype_; PartialTensorShape element_shape_; + bool identical_element_shapes_; bool dynamic_size_; bool clear_after_read_; string tensor_array_name_; // The name used to create the TensorArray. @@ -322,7 +332,8 @@ class TensorArrayGradOp : public TensorArrayCreationOp { output_handle](TensorArray** ret) -> Status { *ret = new TensorArray( key, tensor_array->ElemType(), *tensor_array_output_handle, - array_size, tensor_array->ElemShape(), false /* dynamic_size */, + array_size, tensor_array->ElemShape(), + tensor_array->HasIdenticalElementShapes(), false /* dynamic_size */, true /* multiple_writes_aggregate */, true /* is_grad */, marked_size /* marked_size */, true /* close_after_read */); TF_RETURN_IF_ERROR((*ret)->CopyShapesFrom(tensor_array)); @@ -1003,8 +1014,9 @@ class TensorArrayUnpackOrScatterOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("value", &tensor_value)); TensorShape element_shape(tensor_value->shape()); - OP_REQUIRES(ctx, FastBoundsCheck(element_shape.dim_size(0), - std::numeric_limits::max()), + OP_REQUIRES(ctx, + FastBoundsCheck(element_shape.dim_size(0), + std::numeric_limits::max()), errors::InvalidArgument("tensor dim0 too large to unpack")); OP_REQUIRES( @@ -1204,8 +1216,9 @@ class TensorArraySplitOp : public OpKernel { errors::InvalidArgument( "Expected lengths to be a vector, received shape: ", tensor_lengths->shape().DebugString())); - OP_REQUIRES(ctx, FastBoundsCheck(tensor_lengths->NumElements(), - std::numeric_limits::max()), + OP_REQUIRES(ctx, + FastBoundsCheck(tensor_lengths->NumElements(), + std::numeric_limits::max()), errors::InvalidArgument( "Expected lengths to have < max int32 entries")); diff --git a/tensorflow/core/kernels/tensor_dataset_op.cc b/tensorflow/core/kernels/tensor_dataset_op.cc index db7c9473287..1f690820316 100644 --- a/tensorflow/core/kernels/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/tensor_dataset_op.cc @@ -78,7 +78,7 @@ class TensorDatasetOp : public DatasetOpKernel { components.emplace_back(node); } TF_RETURN_IF_ERROR( - b->AddDatasetWithInputAsList(this, components, output)); + b->AddDataset(this, {}, {std::make_pair(0, components)}, {}, output)); return Status::OK(); } diff --git a/tensorflow/core/kernels/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/tensor_slice_dataset_op.cc index fd36bf524ce..4d0cbdd67c3 100644 --- a/tensorflow/core/kernels/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/tensor_slice_dataset_op.cc @@ -94,7 +94,7 @@ class TensorSliceDatasetOp : public DatasetOpKernel { components.emplace_back(node); } TF_RETURN_IF_ERROR( - b->AddDatasetWithInputAsList(this, components, output)); + b->AddDataset(this, {}, {std::make_pair(0, components)}, {}, output)); return Status::OK(); } diff --git a/tensorflow/core/kernels/zip_dataset_op.cc b/tensorflow/core/kernels/zip_dataset_op.cc index f466c8b268d..96080863ea1 100644 --- a/tensorflow/core/kernels/zip_dataset_op.cc +++ b/tensorflow/core/kernels/zip_dataset_op.cc @@ -78,17 +78,17 @@ class ZipDatasetOp : public DatasetOpKernel { string DebugString() override { return "ZipDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { std::vector input_graph_nodes; input_graph_nodes.reserve(inputs_.size()); for (const auto& input : inputs_) { Node* input_node; - TF_RETURN_IF_ERROR(b->AddParentDataset(input, &input_node)); + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input, &input_node)); input_graph_nodes.emplace_back(input_node); } - TF_RETURN_IF_ERROR( - b->AddDatasetWithInputAsList(this, input_graph_nodes, output)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {}, {std::make_pair(0, input_graph_nodes)}, {}, output)); return Status::OK(); } diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h index 251d58817e7..b89b74b8dec 100644 --- a/tensorflow/core/lib/core/threadpool.h +++ b/tensorflow/core/lib/core/threadpool.h @@ -30,7 +30,7 @@ class ThreadPool { // Constructs a pool that contains "num_threads" threads with specified // "name". env->StartThread() is used to create individual threads with the // given ThreadOptions. If "low_latency_hint" is true the thread pool - // implementation may use it as a hint that lower latency if preferred at the + // implementation may use it as a hint that lower latency is preferred at the // cost of higher CPU usage, e.g. by letting one or more idle threads spin // wait. Conversely, if the threadpool is used to schedule high-latency // operations like I/O the hint should be set to false. diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc index 240e1454e58..d28857803d7 100644 --- a/tensorflow/core/lib/strings/str_util.cc +++ b/tensorflow/core/lib/strings/str_util.cc @@ -84,15 +84,32 @@ inline int hex_digit_to_int(char c) { return x & 0xf; } -bool CUnescapeInternal(StringPiece source, char* dest, +bool CUnescapeInternal(StringPiece source, string* dest, string::size_type* dest_len, string* error) { - char* d = dest; const char* p = source.data(); const char* end = source.end(); const char* last_byte = end - 1; + // We are going to write the result to dest with its iterator. If our string + // implementation uses copy-on-write, this will trigger a copy-on-write of + // dest's buffer; that is, dest will be assigned a new buffer. + // + // Note that the following way is NOT a legal way to modify a string's + // content: + // + // char* d = const_cast(dest->data()); + // + // This won't trigger copy-on-write of the string, and so is dangerous when + // the buffer is shared. + auto d = dest->begin(); + // Small optimization for case where source = dest and there's no escaping - while (p == d && p < end && *p != '\\') p++, d++; + if (source.data() == dest->data()) { + while (p < end && *p != '\\') { + p++; + d++; + } + } while (p < end) { if (*p != '\\') { @@ -192,7 +209,7 @@ bool CUnescapeInternal(StringPiece source, char* dest, p++; // read past letter we escaped } } - *dest_len = d - dest; + *dest_len = d - dest->begin(); return true; } @@ -215,8 +232,7 @@ bool SplitAndParseAsInts(StringPiece text, char delim, bool CUnescape(StringPiece source, string* dest, string* error) { dest->resize(source.size()); string::size_type dest_size; - if (!CUnescapeInternal(source, const_cast(dest->data()), &dest_size, - error)) { + if (!CUnescapeInternal(source, dest, &dest_size, error)) { return false; } dest->erase(dest_size); diff --git a/tensorflow/core/lib/strings/str_util_test.cc b/tensorflow/core/lib/strings/str_util_test.cc index 5c735a87a39..d5909d17aaa 100644 --- a/tensorflow/core/lib/strings/str_util_test.cc +++ b/tensorflow/core/lib/strings/str_util_test.cc @@ -43,6 +43,19 @@ TEST(CUnescape, Basic) { EXPECT_EQ("\320hi\200", ExpectCUnescapeSuccess("\\320hi\\200")); } +TEST(CUnescape, HandlesCopyOnWriteStrings) { + string dest = "hello"; + string read = dest; + // For std::string, read and dest now share the same buffer. + + string error; + StringPiece source = "llohe"; + // CUnescape is going to write "llohe" to dest, so dest's buffer will be + // reallocated, and read's buffer remains untouched. + EXPECT_TRUE(str_util::CUnescape(source, &dest, &error)); + EXPECT_EQ("hello", read); +} + TEST(StripTrailingWhitespace, Basic) { string test; test = "hello"; diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 6833c8e0ea3..ffb608d6007 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -40245,6 +40245,63 @@ op { } is_stateful: true } +op { + name: "TensorArrayV3" + input_arg { + name: "size" + type: DT_INT32 + } + output_arg { + name: "handle" + type: DT_RESOURCE + } + output_arg { + name: "flow" + type: DT_FLOAT + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "element_shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + attr { + name: "dynamic_size" + type: "bool" + default_value { + b: false + } + } + attr { + name: "clear_after_read" + type: "bool" + default_value { + b: true + } + } + attr { + name: "identical_element_shapes" + type: "bool" + default_value { + b: false + } + } + attr { + name: "tensor_array_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} op { name: "TensorArrayWrite" input_arg { diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index 3b1ed217ce1..ac2dc601f1f 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -1346,6 +1346,7 @@ REGISTER_OP("TensorArrayV3") .Attr("element_shape: shape = { unknown_rank: true }") .Attr("dynamic_size: bool = false") .Attr("clear_after_read: bool = true") + .Attr("identical_element_shapes: bool = false") .Attr("tensor_array_name: string = ''") .Output("handle: resource") .Output("flow: float") @@ -1374,6 +1375,12 @@ dynamic_size: A boolean that determines whether writes to the TensorArray clear_after_read: If true (default), Tensors in the TensorArray are cleared after being read. This disables multiple read semantics but allows early release of memory. +identical_element_shapes: If true (default is false), then all + elements in the TensorArray will be expected to have have identical shapes. + This allows certain behaviors, like dynamically checking for + consistent shapes on write, and being able to fill in properly + shaped zero tensors on stack -- even if the element_shape attribute + is not fully defined. tensor_array_name: Overrides the name used for the temporary tensor_array resource. Default value is the name of the 'TensorArray' op (which is guaranteed unique). diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index b7ce9dfce83..4ebb6aad3b9 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -31801,6 +31801,14 @@ op { } description: "If true (default), Tensors in the TensorArray are cleared\nafter being read. This disables multiple read semantics but allows early\nrelease of memory." } + attr { + name: "identical_element_shapes" + type: "bool" + default_value { + b: false + } + description: "If true (default is false), then all\nelements in the TensorArray will be expected to have have identical shapes.\nThis allows certain behaviors, like dynamically checking for\nconsistent shapes on write, and being able to fill in properly\nshaped zero tensors on stack -- even if the element_shape attribute\nis not fully defined." + } attr { name: "tensor_array_name" type: "string" diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc index 5efbac7ad76..7f6d8b06cd3 100644 --- a/tensorflow/core/ops/summary_ops.cc +++ b/tensorflow/core/ops/summary_ops.cc @@ -256,4 +256,17 @@ sample_rate: The sample rate of the signal in hertz. max_outputs: Max number of batch elements to generate audio for. )doc"); +REGISTER_OP("WriteGraphSummary") + .Input("writer: resource") + .Input("global_step: int64") + .Input("tensor: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Writes a `GraphDef` protocol buffer to a `SummaryWriter`. + +writer: Handle of `SummaryWriter`. +global_step: The step to write the summary for. +tensor: A scalar string of the serialized tf.GraphDef proto. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 901fb79d6aa..624145da751 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -41,6 +41,17 @@ cc_library( deps = ["//tensorflow/core:lib"], ) +cc_library( + name = "gcs_dns_cache", + srcs = ["gcs_dns_cache.cc"], + hdrs = ["gcs_dns_cache.h"], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":http_request", + "//tensorflow/core:lib", + ], +) + cc_library( name = "gcs_file_system", srcs = ["gcs_file_system.cc"], @@ -51,6 +62,7 @@ cc_library( ":curl_http_request", ":expiring_lru_cache", ":file_block_cache", + ":gcs_dns_cache", ":google_auth_provider", ":http_request", ":retrying_file_system", @@ -231,6 +243,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gcs_dns_cache_test", + size = "small", + srcs = ["gcs_dns_cache_test.cc"], + deps = [ + ":gcs_dns_cache", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "curl_http_request_test", size = "small", diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc index e2d935f35eb..d01734ba3a6 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.cc +++ b/tensorflow/core/platform/cloud/curl_http_request.cc @@ -131,6 +131,9 @@ CurlHttpRequest::~CurlHttpRequest() { if (curl_headers_) { libcurl_->curl_slist_free_all(curl_headers_); } + if (resolve_list_) { + libcurl_->curl_slist_free_all(resolve_list_); + } if (put_body_) { fclose(put_body_); } @@ -212,6 +215,17 @@ Status CurlHttpRequest::AddHeader(const string& name, const string& value) { return Status::OK(); } +Status CurlHttpRequest::AddResolveOverride(const string& hostname, int64 port, + const string& ip_addr) { + TF_RETURN_IF_ERROR(CheckInitialized()); + TF_RETURN_IF_ERROR(CheckNotSent()); + // Resolve values are hostname:port:IP.add.ress + resolve_list_ = libcurl_->curl_slist_append( + resolve_list_, + strings::StrCat(hostname, ":", port, ":", ip_addr).c_str()); + return Status::OK(); +} + Status CurlHttpRequest::AddAuthBearerHeader(const string& auth_token) { TF_RETURN_IF_ERROR(CheckInitialized()); TF_RETURN_IF_ERROR(CheckNotSent()); @@ -376,6 +390,9 @@ Status CurlHttpRequest::Send() { if (curl_headers_) { libcurl_->curl_easy_setopt(curl_, CURLOPT_HTTPHEADER, curl_headers_); } + if (resolve_list_) { + libcurl_->curl_easy_setopt(curl_, CURLOPT_RESOLVE, resolve_list_); + } libcurl_->curl_easy_setopt(curl_, CURLOPT_HEADERDATA, reinterpret_cast(this)); libcurl_->curl_easy_setopt(curl_, CURLOPT_HEADERFUNCTION, diff --git a/tensorflow/core/platform/cloud/curl_http_request.h b/tensorflow/core/platform/cloud/curl_http_request.h index c7a555de10c..2396593d6de 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.h +++ b/tensorflow/core/platform/cloud/curl_http_request.h @@ -71,6 +71,9 @@ class CurlHttpRequest : public HttpRequest { /// Sets a request header. Status AddHeader(const string& name, const string& value) override; + Status AddResolveOverride(const string& hostname, int64 port, + const string& ip_addr) override; + /// Sets the 'Authorization' header to the value of 'Bearer ' + auth_token. Status AddAuthBearerHeader(const string& auth_token) override; @@ -146,6 +149,7 @@ class CurlHttpRequest : public HttpRequest { std::vector* response_buffer_ = nullptr; CURL* curl_ = nullptr; curl_slist* curl_headers_ = nullptr; + curl_slist* resolve_list_ = nullptr; std::vector default_response_buffer_; diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache.cc b/tensorflow/core/platform/cloud/gcs_dns_cache.cc new file mode 100644 index 00000000000..63f2da065db --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_dns_cache.cc @@ -0,0 +1,135 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/gcs_dns_cache.h" + +#include +#include +#include + +namespace tensorflow { + +namespace { + +constexpr char kStorageHost[] = "storage.googleapis.com"; +constexpr char kWwwHost[] = "www.googleapis.com"; + +} // namespace + +GcsDnsCache::GcsDnsCache(Env* env, int64 refresh_rate_secs) + : env_(env), refresh_rate_secs_(refresh_rate_secs) {} + +Status GcsDnsCache::AnnotateRequest(HttpRequest* request) { + // TODO(saeta): Blacklist failing IP addresses. + mutex_lock l(mu_); + if (!started_) { + DCHECK(!worker_) << "Worker thread already exists!"; + // Perform DNS resolutions to warm the cache. + std::vector www_addresses = ResolveName(kWwwHost); + std::vector storage_addresses = ResolveName(kStorageHost); + www_addresses.swap(www_addresses_); + storage_addresses.swap(storage_addresses_); + + // Note: we opt to use a thread instead of a delayed closure. + worker_.reset(env_->StartThread( + {}, "gcs_dns_worker", std::bind(&GcsDnsCache::WorkerThread, this))); + started_ = true; + } + if (!storage_addresses_.empty()) { + std::uniform_int_distribution<> storage_dist(0, + storage_addresses_.size() - 1); + size_t index = storage_dist(random_); + TF_RETURN_IF_ERROR(request->AddResolveOverride(kStorageHost, 443, + storage_addresses_[index])); + } else { + LOG(WARNING) << "No IP addresses available for " << kStorageHost; + } + if (!www_addresses_.empty()) { + std::uniform_int_distribution<> www_dist(0, www_addresses_.size() - 1); + size_t index = www_dist(random_); + TF_RETURN_IF_ERROR( + request->AddResolveOverride(kWwwHost, 443, www_addresses_[index])); + } else { + LOG(WARNING) << "No IP addresses available for " << kWwwHost; + } + return Status::OK(); +} + +/* static */ std::vector GcsDnsCache::ResolveName(const string& name) { + addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; // Only use IPv4 for now. + hints.ai_socktype = SOCK_STREAM; + addrinfo* result = nullptr; + int return_code = getaddrinfo(name.c_str(), nullptr, &hints, &result); + + std::vector output; + if (return_code == 0) { + for (addrinfo* i = result; i != nullptr; i = i->ai_next) { + if (i->ai_family != AF_INET || i->ai_addr->sa_family != AF_INET) { + LOG(WARNING) << "Non-IPv4 address returned. ai_family: " << i->ai_family + << ". sa_family: " << i->ai_addr->sa_family << "."; + continue; + } + char buf[INET_ADDRSTRLEN]; + void* address_ptr = + &(reinterpret_cast(i->ai_addr)->sin_addr); + const char* formatted = nullptr; + if ((formatted = inet_ntop(i->ai_addr->sa_family, address_ptr, buf, + INET_ADDRSTRLEN)) == nullptr) { + LOG(ERROR) << "Error converting response to IP address for " << name + << ": " << strerror(errno); + } else { + output.emplace_back(buf); + } + } + } else { + if (return_code == EAI_SYSTEM) { + LOG(ERROR) << "Error resolving " << name + << " (EAI_SYSTEM): " << strerror(errno); + } else { + LOG(ERROR) << "Error resolving " << name << ": " + << gai_strerror(return_code); + } + } + if (result != nullptr) { + freeaddrinfo(result); + } + return output; +} + +void GcsDnsCache::WorkerThread() { + while (true) { + { + // Don't immediately re-resolve the addresses. + mutex_lock l(mu_); + if (cancelled_) return; + cond_var_.wait_for(l, std::chrono::seconds(refresh_rate_secs_)); + if (cancelled_) return; + } + // Resolve DNS values + std::vector www_addresses = ResolveName(kWwwHost); + std::vector storage_addresses = ResolveName(kStorageHost); + + { + mutex_lock l(mu_); + // Update instance variables. + www_addresses.swap(www_addresses_); + storage_addresses.swap(storage_addresses_); + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache.h b/tensorflow/core/platform/cloud/gcs_dns_cache.h new file mode 100644 index 00000000000..7a4d3847a5a --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_dns_cache.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_ +#define THIRD_PARTY_TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_ + +#include + +#include "tensorflow/core/platform/cloud/http_request.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { +const int64 kDefaultRefreshRateSecs = 60; + +// DnsCache is a userspace DNS cache specialized for the GCS filesystem. +// +// Some environments have unreliable DNS resolvers. DnsCache ameliorates the +// situation by radically reducing the number of DNS requests by performing +// 2 DNS queries per minute (by default) on a background thread. Updated cache +// entries are used to override curl's DNS resolution processes. +class GcsDnsCache { + public: + // Default no-argument constructor. + GcsDnsCache() : GcsDnsCache(kDefaultRefreshRateSecs) {} + + // Constructs a GcsDnsCache with the specified refresh rate. + GcsDnsCache(int64 refresh_rate_secs) + : GcsDnsCache(Env::Default(), refresh_rate_secs) {} + + GcsDnsCache(Env* env, int64 refresh_rate_secs); + + ~GcsDnsCache() { + mutex_lock l(mu_); + cancelled_ = true; + cond_var_.notify_one(); + } + + // Annotate the given HttpRequest with resolve overrides from the cache. + Status AnnotateRequest(HttpRequest* request); + + private: + static std::vector ResolveName(const string& name); + void WorkerThread(); + + // Define a friend class for testing. + friend class GcsDnsCacheTest; + + mutex mu_; + Env* env_; + condition_variable cond_var_; + std::default_random_engine random_ GUARDED_BY(mu_); + bool started_ GUARDED_BY(mu_) = false; + bool cancelled_ GUARDED_BY(mu_) = false; + std::vector www_addresses_ GUARDED_BY(mu_); + std::vector storage_addresses_ GUARDED_BY(mu_); + std::unique_ptr worker_ GUARDED_BY(mu_); // After mutable vars. + const int64 refresh_rate_secs_; +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_ diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc new file mode 100644 index 00000000000..cba6caff22e --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/gcs_dns_cache.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class TestHttpRequest : public HttpRequest { + public: + Status Init() override { return Status::OK(); } + Status SetUri(const string& uri) override { return Status::OK(); } + Status SetRange(uint64 start, uint64 end) override { return Status::OK(); } + Status AddHeader(const string& name, const string& value) override { + return Status::OK(); + } + Status AddResolveOverride(const string& hostname, int64 port, + const string& ip_addr) override { + EXPECT_EQ(port, 443) << "Unexpected port set for hostname: " << hostname; + auto itr = resolve_overrides_.find(hostname); + EXPECT_EQ(itr, resolve_overrides_.end()) + << "Hostname " << hostname << "already in map: " << itr->second; + + resolve_overrides_.insert( + std::map::value_type(hostname, ip_addr)); + return Status::OK(); + } + + Status AddAuthBearerHeader(const string& auth_token) override { + return Status::OK(); + } + + Status SetDeleteRequest() override { return Status::OK(); } + + Status SetPutFromFile(const string& body_filepath, size_t offset) override { + return Status::OK(); + } + Status SetPutEmptyBody() override { return Status::OK(); } + + Status SetPostFromBuffer(const char* buffer, size_t size) override { + return Status::OK(); + } + Status SetPostEmptyBody() override { return Status::OK(); } + + Status SetResultBuffer(std::vector* out_buffer) override { + return Status::OK(); + } + + string GetResponseHeader(const string& name) const override { return ""; } + uint64 GetResponseCode() const override { return 0; } + Status Send() override { return Status::OK(); } + string EscapeString(const string& str) override { return ""; } + + std::map resolve_overrides_; +}; + +// Friend class for testing. +// +// It is written this way (as opposed to using FRIEND_TEST) to avoid a +// non-test-time dependency on gunit. +class GcsDnsCacheTest : public ::testing::Test { + protected: + void ResolveNameTest() { + auto response = GcsDnsCache::ResolveName("www.googleapis.com"); + EXPECT_LT(1, response.size()) << str_util::Join(response, ", "); + } + + void AnnotateRequestTest() { + GcsDnsCache d; + { + mutex_lock l(d.mu_); + d.started_ = true; // Avoid creating a thread. + d.www_addresses_ = {"192.168.1.1"}; + d.storage_addresses_ = {"172.134.1.1"}; + } + + TestHttpRequest req; + Status s = d.AnnotateRequest(&req); + EXPECT_TRUE(s.ok()) << s; + EXPECT_EQ("192.168.1.1", req.resolve_overrides_["www.googleapis.com"]); + EXPECT_EQ("172.134.1.1", req.resolve_overrides_["storage.googleapis.com"]); + } + + void SuccessfulCleanupTest() { + // Create a DnsCache object, start the worker thread, ensure it cleans up in + // a timely manner. + GcsDnsCache d; + TestHttpRequest req; + Status s = d.AnnotateRequest(&req); + EXPECT_TRUE(s.ok()) << s; + } +}; + +TEST_F(GcsDnsCacheTest, ResolveName) { ResolveNameTest(); } + +TEST_F(GcsDnsCacheTest, AnnotateRequest) { AnnotateRequestTest(); } + +TEST_F(GcsDnsCacheTest, SuccessfulCleanup) { SuccessfulCleanupTest(); } + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 17fe704b79a..9287de7237d 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -89,6 +89,10 @@ constexpr char kMatchingPathsCacheMaxEntries[] = constexpr size_t kMatchingPathsCacheDefaultMaxEntries = 1024; // The file statistics returned by Stat() for directories. const FileStatistics DIRECTORY_STAT(0, 0, true); +// Some environments exhibit unreliable DNS resolution. Set this environment +// variable to a positive integer describing the frequency used to refresh the +// userspace DNS cache. +constexpr char kResolveCacheSecs[] = "GCS_RESOLVE_REFRESH_SECS"; Status GetTmpFilename(string* filename) { if (!filename) { @@ -434,8 +438,8 @@ class GcsWritableFile : public WritableFile { std::unique_ptr request(http_request_factory_->Create()); TF_RETURN_IF_ERROR(request->Init()); TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( - kGcsUploadUriBase, "b/", bucket_, "/o?uploadType=resumable&name=", - request->EscapeString(object_)))); + kGcsUploadUriBase, "b/", bucket_, + "/o?uploadType=resumable&name=", request->EscapeString(object_)))); TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); TF_RETURN_IF_ERROR(request->AddHeader("X-Upload-Content-Length", std::to_string(file_size))); @@ -624,6 +628,12 @@ GcsFileSystem::GcsFileSystem() } matching_paths_cache_.reset(new ExpiringLRUCache>( matching_paths_cache_max_age, matching_paths_cache_max_entries)); + + int64 resolve_frequency_secs; + if (GetEnvVar(kResolveCacheSecs, strings::safe_strto64, + &resolve_frequency_secs)) { + dns_cache_.reset(new GcsDnsCache(resolve_frequency_secs)); + } } GcsFileSystem::GcsFileSystem( @@ -678,6 +688,11 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1)); TF_RETURN_IF_ERROR(request->SetResultBuffer(out)); + + if (dns_cache_) { + TF_RETURN_IF_ERROR(dns_cache_->AnnotateRequest(request.get())); + } + TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://", bucket, "/", object); return Status::OK(); @@ -821,6 +836,11 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket, "?fields=size%2Cupdated"))); TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); + + if (dns_cache_) { + TF_RETURN_IF_ERROR(dns_cache_->AnnotateRequest(request.get())); + } + TF_RETURN_WITH_CONTEXT_IF_ERROR( request->Send(), " when reading metadata of gs://", bucket, "/", object); @@ -959,12 +979,12 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, uri = strings::StrCat(uri, "&delimiter=%2F"); } if (!object_prefix.empty()) { - uri = strings::StrCat(uri, "&prefix=", - request->EscapeString(object_prefix)); + uri = strings::StrCat(uri, + "&prefix=", request->EscapeString(object_prefix)); } if (!nextPageToken.empty()) { - uri = strings::StrCat(uri, "&pageToken=", - request->EscapeString(nextPageToken)); + uri = strings::StrCat( + uri, "&pageToken=", request->EscapeString(nextPageToken)); } if (max_results - retrieved_results < kGetChildrenDefaultPageSize) { uri = @@ -973,6 +993,11 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, TF_RETURN_IF_ERROR(request->SetUri(uri)); TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); + + if (dns_cache_) { + TF_RETURN_IF_ERROR(dns_cache_->AnnotateRequest(request.get())); + } + TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading ", dirname); Json::Value root; StringPiece response_piece = diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index 36a1d42fdef..4b4853c838a 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/cloud/auth_provider.h" #include "tensorflow/core/platform/cloud/expiring_lru_cache.h" #include "tensorflow/core/platform/cloud/file_block_cache.h" +#include "tensorflow/core/platform/cloud/gcs_dns_cache.h" #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/cloud/retrying_file_system.h" #include "tensorflow/core/platform/file_system.h" @@ -141,6 +142,7 @@ class GcsFileSystem : public FileSystem { std::unique_ptr auth_provider_; std::unique_ptr http_request_factory_; std::unique_ptr file_block_cache_; + std::unique_ptr dns_cache_; using StatCache = ExpiringLRUCache; std::unique_ptr stat_cache_; diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h index 8182b63d5b2..02d9e9054ad 100644 --- a/tensorflow/core/platform/cloud/http_request.h +++ b/tensorflow/core/platform/cloud/http_request.h @@ -64,6 +64,14 @@ class HttpRequest { /// Sets a request header. virtual Status AddHeader(const string& name, const string& value) = 0; + /// Sets a DNS resolve mapping (to skip DNS resolution). + /// + /// Note: because GCS is available over HTTPS, we cannot replace the hostname + /// in the URI with an IP address, as that will cause the certificate check + /// to fail. + virtual Status AddResolveOverride(const string& hostname, int64 port, + const string& ip_addr) = 0; + /// Sets the 'Authorization' header to the value of 'Bearer ' + auth_token. virtual Status AddAuthBearerHeader(const string& auth_token) = 0; diff --git a/tensorflow/core/profiler/g3doc/profiler_ui.jpg b/tensorflow/core/profiler/g3doc/profiler_ui.jpg index 36aa94502a8..77346e61ae9 100644 Binary files a/tensorflow/core/profiler/g3doc/profiler_ui.jpg and b/tensorflow/core/profiler/g3doc/profiler_ui.jpg differ diff --git a/tensorflow/core/profiler/internal/tfprof_op.cc b/tensorflow/core/profiler/internal/tfprof_op.cc index c04b0ea0c62..5a8429d4893 100644 --- a/tensorflow/core/profiler/internal/tfprof_op.cc +++ b/tensorflow/core/profiler/internal/tfprof_op.cc @@ -109,7 +109,6 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, fprintf(stderr, "Only 'code' view supports pprof output now.\n"); return root_.get(); } - if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) { root_->formatted_str = FormatNode(root_.get(), root_.get(), opts); } @@ -130,7 +129,6 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, nodes.push_back(n.second.get()); } nodes = SortNodes(nodes, opts); - // pre keeps track of previous visited node. OpNode* pre = nullptr; std::vector account_nodes; @@ -166,10 +164,6 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, (*it)->AddSelfToTotalStats(); if (pre) (*it)->AggregateTotalStats(pre); } - if (pre) { - (*it)->mutable_proto()->add_children()->MergeFrom(pre->proto()); - pre->mutable_proto()->clear_children(); - } pre = *it; } if (opts.account_displayed_op_only) { @@ -178,11 +172,6 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, root_->AggregateTotalStats(pre); } } - if (pre) { - root_->mutable_proto()->add_children()->MergeFrom(pre->proto()); - pre->mutable_proto()->clear_children(); - } - if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) { string display_str = FormatLegend(opts); for (OpNode* node : show_nodes) { @@ -192,6 +181,13 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, // TODO(xpan): Is it the right choice? root_->formatted_str = display_str; } + // Populate the chidren field. + auto* pre_pb = root_->mutable_proto(); + for (auto& show_node : show_nodes) { + pre_pb->clear_children(); + pre_pb->add_children()->Swap(show_node->mutable_proto()); + pre_pb = pre_pb->mutable_children(0); + } return root_.get(); } diff --git a/tensorflow/core/profiler/profiler.cc b/tensorflow/core/profiler/profiler.cc index a5e513aa21c..b280242df18 100644 --- a/tensorflow/core/profiler/profiler.cc +++ b/tensorflow/core/profiler/profiler.cc @@ -266,7 +266,18 @@ int Run(int argc, char** argv) { linenoiseSetCompletionCallback(completion); linenoiseHistoryLoad(".tfprof_history.txt"); - for (char* line = nullptr; (line = linenoise("tfprof> ")) != nullptr;) { + bool looped = false; + while (true) { + char* line = linenoise("tfprof> "); + if (line == nullptr) { + if (!looped) { + fprintf(stderr, + "Cannot start interative shell, " + "use 'bazel-bin' instead of 'bazel run'.\n"); + } + break; + } + looped = true; string line_s = line; free(line); diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index 8f3457e97ce..3b5d1563a26 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -30,11 +30,13 @@ message RewriterConfig { } // Optimize tensor layouts - bool optimize_tensor_layout = 1; + Toggle layout_optimizer = 1; // Fold constants (default is ON) Toggle constant_folding = 3; // Arithmetic optimizations (default is ON) Toggle arithmetic_optimization = 7; + // Control dependency optimizations (default is OFF). + Toggle dependency_optimization = 8; // If true, don't remove unnecessary ops from the graph bool disable_model_pruning = 2; diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 34a5cff3660..e7b3f36fcc7 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -64,6 +64,22 @@ message CreateWorkerSessionRequest { message CreateWorkerSessionResponse { } +//////////////////////////////////////////////////////////////////////////////// +// +// DeleteSession method request/response messages +// +// Deletes all worker-side state associated with the given session handle. +// +//////////////////////////////////////////////////////////////////////////////// + +message DeleteWorkerSessionRequest { + // Sessions are identified by a given handle. + string session_handle = 1; +} + +message DeleteWorkerSessionResponse { +} + //////////////////////////////////////////////////////////////////////////////// // // RegisterGraph method request/response messages diff --git a/tensorflow/core/protobuf/worker_service.proto b/tensorflow/core/protobuf/worker_service.proto index 3de9e48b78e..e1bfb04d7c5 100644 --- a/tensorflow/core/protobuf/worker_service.proto +++ b/tensorflow/core/protobuf/worker_service.proto @@ -43,6 +43,10 @@ service WorkerService { rpc CreateWorkerSession(CreateWorkerSessionRequest) returns (CreateWorkerSessionResponse); + // See worker.proto for details. + rpc DeleteWorkerSession(DeleteWorkerSessionRequest) + returns (DeleteWorkerSessionResponse); + // See worker.proto for details. rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse); diff --git a/tensorflow/docs_src/mobile/index.md b/tensorflow/docs_src/mobile/index.md index a10db74364b..6bcd7d09d9c 100644 --- a/tensorflow/docs_src/mobile/index.md +++ b/tensorflow/docs_src/mobile/index.md @@ -2,8 +2,8 @@ TensorFlow was designed to be a good deep learning solution for mobile platforms. Currently we have two solutions for deploying machine learning -applications on mobile and embedded devices: @{$mobile/mobile_intro$TensorFlow -for Mobile} and @{$mobile/tflite$TensorFlow Lite}. +applications on mobile and embedded devices: +@{$mobile/mobile_intro$TensorFlow for Mobile} and @{$mobile/tflite$TensorFlow Lite}. ## TensorFlow Lite versus TensorFlow Mobile diff --git a/tensorflow/docs_src/tutorials/linear.md b/tensorflow/docs_src/tutorials/linear.md index a6517549c36..d333d012790 100644 --- a/tensorflow/docs_src/tutorials/linear.md +++ b/tensorflow/docs_src/tutorials/linear.md @@ -175,7 +175,7 @@ the name of a `FeatureColumn`. Each key's value is a tensor containing the values of that feature for all data instances. See @{$input_fn$Building Input Functions with tf.estimator} for a more comprehensive look at input functions, and `input_fn` in the -[linear models tutorial code](https://www.tensorflow.org/code/tensorflow/examples/learn/wide_n_deep_tutorial.py) +[linear models tutorial code](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py) for an example implementation of an input function. The input function is passed to the `train()` and `evaluate()` calls that diff --git a/tensorflow/examples/learn/BUILD b/tensorflow/examples/learn/BUILD index 23a42a60ba4..aba7f600b53 100644 --- a/tensorflow/examples/learn/BUILD +++ b/tensorflow/examples/learn/BUILD @@ -113,13 +113,6 @@ py_binary( ], ) -py_binary( - name = "wide_n_deep_tutorial", - srcs = ["wide_n_deep_tutorial.py"], - srcs_version = "PY2AND3", - deps = ["//tensorflow:tensorflow_py"], -) - py_binary( name = "mnist", srcs = ["mnist.py"], @@ -153,7 +146,6 @@ sh_test( ":text_classification_character_cnn", ":text_classification_character_rnn", ":text_classification_cnn", - ":wide_n_deep_tutorial", ], tags = [ "manual", diff --git a/tensorflow/examples/learn/README.md b/tensorflow/examples/learn/README.md index 70d9db85ee5..b74a8f39d98 100644 --- a/tensorflow/examples/learn/README.md +++ b/tensorflow/examples/learn/README.md @@ -23,7 +23,7 @@ processing (`pip install -U pandas`). ## Specialized Models * [Building a Random Forest Model](https://www.tensorflow.org/code/tensorflow/examples/learn/random_forest_mnist.py) -* [Building a Wide & Deep Model](https://www.tensorflow.org/code/tensorflow/examples/learn/wide_n_deep_tutorial.py) +* [Building a Wide & Deep Model](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py) * [Building a Residual Network Model](https://www.tensorflow.org/code/tensorflow/examples/learn/resnet.py) ## Text classification diff --git a/tensorflow/examples/learn/examples_test.sh b/tensorflow/examples/learn/examples_test.sh index b8763de471c..ef5e8a5de25 100755 --- a/tensorflow/examples/learn/examples_test.sh +++ b/tensorflow/examples/learn/examples_test.sh @@ -56,4 +56,3 @@ test text_classification_builtin_rnn_model --test_with_fake_data test text_classification_character_cnn --test_with_fake_data test text_classification_character_rnn --test_with_fake_data test text_classification_cnn --test_with_fake_data -test wide_n_deep_tutorial diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py deleted file mode 100644 index 072353392a9..00000000000 --- a/tensorflow/examples/learn/wide_n_deep_tutorial.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Example code for TensorFlow Wide & Deep Tutorial using TF High Level API. - -This example uses APIs in Tensorflow 1.4 or above. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import shutil -import sys -import tempfile - -import pandas as pd -from six.moves import urllib -import tensorflow as tf - - -CSV_COLUMNS = [ - "age", "workclass", "fnlwgt", "education", "education_num", - "marital_status", "occupation", "relationship", "race", "gender", - "capital_gain", "capital_loss", "hours_per_week", "native_country", - "income_bracket" -] - -gender = tf.feature_column.categorical_column_with_vocabulary_list( - "gender", ["Female", "Male"]) -education = tf.feature_column.categorical_column_with_vocabulary_list( - "education", [ - "Bachelors", "HS-grad", "11th", "Masters", "9th", - "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th", - "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th", - "Preschool", "12th" - ]) -marital_status = tf.feature_column.categorical_column_with_vocabulary_list( - "marital_status", [ - "Married-civ-spouse", "Divorced", "Married-spouse-absent", - "Never-married", "Separated", "Married-AF-spouse", "Widowed" - ]) -relationship = tf.feature_column.categorical_column_with_vocabulary_list( - "relationship", [ - "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried", - "Other-relative" - ]) -workclass = tf.feature_column.categorical_column_with_vocabulary_list( - "workclass", [ - "Self-emp-not-inc", "Private", "State-gov", "Federal-gov", - "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked" - ]) - -# To show an example of hashing: -occupation = tf.feature_column.categorical_column_with_hash_bucket( - "occupation", hash_bucket_size=1000) -native_country = tf.feature_column.categorical_column_with_hash_bucket( - "native_country", hash_bucket_size=1000) - -# Continuous base columns. -age = tf.feature_column.numeric_column("age") -education_num = tf.feature_column.numeric_column("education_num") -capital_gain = tf.feature_column.numeric_column("capital_gain") -capital_loss = tf.feature_column.numeric_column("capital_loss") -hours_per_week = tf.feature_column.numeric_column("hours_per_week") - -# Transformations. -age_buckets = tf.feature_column.bucketized_column( - age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65]) - -# Wide columns and deep columns. -base_columns = [ - gender, education, marital_status, relationship, workclass, occupation, - native_country, age_buckets, -] - -crossed_columns = [ - tf.feature_column.crossed_column( - ["education", "occupation"], hash_bucket_size=1000), - tf.feature_column.crossed_column( - [age_buckets, "education", "occupation"], hash_bucket_size=1000), - tf.feature_column.crossed_column( - ["native_country", "occupation"], hash_bucket_size=1000) -] - -deep_columns = [ - tf.feature_column.indicator_column(workclass), - tf.feature_column.indicator_column(education), - tf.feature_column.indicator_column(gender), - tf.feature_column.indicator_column(relationship), - # To show an example of embedding - tf.feature_column.embedding_column(native_country, dimension=8), - tf.feature_column.embedding_column(occupation, dimension=8), - age, - education_num, - capital_gain, - capital_loss, - hours_per_week, -] - - -FLAGS = None - - -def maybe_download(train_data, test_data): - """Maybe downloads training data and returns train and test file names.""" - if train_data: - train_file_name = train_data - else: - train_file = tempfile.NamedTemporaryFile(delete=False) - urllib.request.urlretrieve( - "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", - train_file.name) # pylint: disable=line-too-long - train_file_name = train_file.name - train_file.close() - print("Training data is downloaded to %s" % train_file_name) - - if test_data: - test_file_name = test_data - else: - test_file = tempfile.NamedTemporaryFile(delete=False) - urllib.request.urlretrieve( - "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", - test_file.name) # pylint: disable=line-too-long - test_file_name = test_file.name - test_file.close() - print("Test data is downloaded to %s"% test_file_name) - - return train_file_name, test_file_name - - -def build_estimator(model_dir, model_type): - """Build an estimator.""" - if model_type == "wide": - m = tf.estimator.LinearClassifier( - model_dir=model_dir, feature_columns=base_columns + crossed_columns) - elif model_type == "deep": - m = tf.estimator.DNNClassifier( - model_dir=model_dir, - feature_columns=deep_columns, - hidden_units=[100, 50]) - else: - m = tf.estimator.DNNLinearCombinedClassifier( - model_dir=model_dir, - linear_feature_columns=crossed_columns, - dnn_feature_columns=deep_columns, - dnn_hidden_units=[100, 50]) - return m - - -def input_fn(data_file, num_epochs, shuffle): - """Returns an `input_fn` required by Estimator train/evaluate. - - Args: - data_file: The file path to the dataset. - num_epochs: Number of epochs to iterate over data. If `None`, `input_fn` - will generate infinite stream of data. - shuffle: bool, whether to read the data in random order. - """ - df_data = pd.read_csv( - tf.gfile.Open(data_file), - names=CSV_COLUMNS, - skipinitialspace=True, - engine="python", - skiprows=1) - # remove NaN elements - df_data = df_data.dropna(how="any", axis=0) - labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int) - - return tf.estimator.inputs.pandas_input_fn( - x=df_data, - y=labels, - batch_size=100, - num_epochs=num_epochs, - shuffle=shuffle, - num_threads=1) - - -def main(_): - tf.logging.set_verbosity(tf.logging.INFO) - - train_file_name, test_file_name = maybe_download(FLAGS.train_data, - FLAGS.test_data) - - # Specify file path below if want to find the output easily - model_dir = FLAGS.model_dir if FLAGS.model_dir else tempfile.mkdtemp() - - estimator = build_estimator(model_dir, FLAGS.model_type) - - # `tf.estimator.TrainSpec`, `tf.estimator.EvalSpec`, and - # `tf.estimator.train_and_evaluate` API are available in TF 1.4. - train_spec = tf.estimator.TrainSpec( - input_fn=input_fn(train_file_name, num_epochs=None, shuffle=True), - max_steps=FLAGS.train_steps) - - eval_spec = tf.estimator.EvalSpec( - input_fn=input_fn(test_file_name, num_epochs=1, shuffle=False), - # set steps to None to run evaluation until all data consumed. - steps=None) - - tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) - - # Manual cleanup - shutil.rmtree(model_dir) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.register("type", "bool", lambda v: v.lower() == "true") - parser.add_argument( - "--model_dir", - type=str, - default="", - help="Base directory for output models." - ) - parser.add_argument( - "--model_type", - type=str, - default="wide_n_deep", - help="Valid model types: {'wide', 'deep', 'wide_n_deep'}." - ) - parser.add_argument( - "--train_steps", - type=int, - default=2000, - help="Number of training steps." - ) - parser.add_argument( - "--train_data", - type=str, - default="", - help="Path to the training data." - ) - parser.add_argument( - "--test_data", - type=str, - default="", - help="Path to the test data." - ) - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 5a6ae4fa5ff..869213eb172 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -14797,6 +14797,21 @@ func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr { } } +// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value. +// +// value: If true (default is false), then all +// elements in the TensorArray will be expected to have have identical shapes. +// This allows certain behaviors, like dynamically checking for +// consistent shapes on write, and being able to fill in properly +// shaped zero tensors on stack -- even if the element_shape attribute +// is not fully defined. +// If not specified, defaults to false +func TensorArrayV3IdenticalElementShapes(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["identical_element_shapes"] = value + } +} + // TensorArrayV3TensorArrayName sets the optional tensor_array_name attribute to value. // // value: Overrides the name used for the temporary tensor_array @@ -20553,6 +20568,27 @@ func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } +// Writes a `GraphDef` protocol buffer to a `SummaryWriter`. +// +// Arguments: +// writer: Handle of `SummaryWriter`. +// global_step: The step to write the summary for. +// tensor: A scalar string of the serialized tf.GraphDef proto. +// +// Returns the created operation. +func WriteGraphSummary(scope *Scope, writer tf.Output, global_step tf.Output, tensor tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "WriteGraphSummary", + Input: []tf.Input{ + writer, global_step, tensor, + }, + } + return scope.AddOperation(opspec) +} + // MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. type MaxPool3DGradGradAttr func(optionalAttr) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f4dd565fc34..406ff30cebc 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -448,6 +448,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/python/eager:python_eager_op_gen", ], @@ -3853,15 +3854,15 @@ py_library( deps = [ ":array_ops", ":control_flow_ops", - ":framework", ":framework_for_generated_wrappers", - ":init_ops", + ":platform", + ":tensor_util", ":util", ":variable_scope", ":variables", + "//tensorflow/python/eager:context", "//tensorflow/python/estimator:util", "//third_party/py/numpy", - "@six_archive//:six", ], ) @@ -3872,12 +3873,14 @@ py_library( "layers/core.py", "layers/layers.py", "layers/maxout.py", + "layers/network.py", "layers/normalization.py", "layers/pooling.py", ], srcs_version = "PY2AND3", deps = [ ":array_ops", + ":array_ops_gen", ":control_flow_ops", ":framework", ":framework_for_generated_wrappers", @@ -3885,12 +3888,18 @@ py_library( ":layers_base", ":math_ops", ":nn", + ":nn_ops", + ":platform", + ":resource_variable_ops", + ":resource_variable_ops_gen", ":standard_ops", + ":state_ops", ":training", ":util", ":variable_scope", ":variables", "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:util", "//third_party/py/numpy", "@six_archive//:six", ], @@ -3903,14 +3912,36 @@ py_test( main = "layers/base_test.py", srcs_version = "PY2AND3", deps = [ + ":array_ops", ":client_testlib", ":framework_for_generated_wrappers", ":framework_test_lib", ":init_ops", ":layers", + ":layers_base", ":math_ops", ":random_ops", ":variable_scope", + "//tensorflow/python/eager:context", + ], +) + +py_test( + name = "layers_network_test", + size = "small", + srcs = ["layers/network_test.py"], + main = "layers/network_test.py", + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":client_testlib", + ":framework_for_generated_wrappers", + ":framework_test_lib", + ":layers", + ":layers_base", + ":sparse_ops", + "//tensorflow/python/eager:context", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index c36647b21c4..b491a637bac 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -61,7 +61,6 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":context", - ":memory_trace", "//tensorflow/python:errors", "//tensorflow/python:pywrap_tensorflow", ], @@ -88,12 +87,6 @@ py_library( visibility = ["//tensorflow:internal"], ) -py_library( - name = "memory_trace", - srcs = ["memory_trace.py"], - srcs_version = "PY2AND3", -) - cuda_py_test( name = "tensor_test", srcs = ["tensor_test.py"], @@ -222,6 +215,7 @@ cc_library( ":python_eager_op_gen", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 33601a1edcc..25f7ae785e6 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -305,6 +305,7 @@ def implicit_val_and_grad(f): is not known ahead of time. Example: + ```python dense_layer = tf.layers.Dense(1) def loss(x, y): @@ -378,6 +379,7 @@ def implicit_grad(f): is not known ahead of time. Example: + ```python dense_layer = tf.layers.Dense(1) def loss(x, y): @@ -733,20 +735,28 @@ _last_shape_dtype = [None, None] _last_zero = [None] +def _fast_fill(value, shape, dtype): + return array_ops.fill(shape, constant_op.constant(value, dtype=dtype)) + + def _zeros(shape, dtype): """Wraps array_ops.zeros to cache last zero for a given shape and dtype.""" if [shape, dtype] != _last_shape_dtype: _last_shape_dtype[:] = [shape, dtype] - _last_zero[0] = array_ops.zeros(shape, dtype) + _last_zero[0] = _fast_fill(0, shape, dtype) return _last_zero[0] +def _ones(shape, dtype): + return _fast_fill(1, shape, dtype) + + _default_vspace = imperative_grad.VSpace( num_elements_fn=_num_elements, aggregate_fn=_aggregate_grads, tensor_id=ops.tensor_id, zeros=_zeros, - ones=array_ops.ones) + ones=_ones) class GradientTape(object): diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 435505edd74..9849f0f322e 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -170,6 +170,18 @@ class MicroBenchmarks(test.Benchmark): m = self._m_2 self._run(lambda: gen_array_ops.identity(m), 30000) + def benchmark_tfe_py_execute_identity(self): + m = self._m_2 + ctx_handle = context.context()._handle + attrs = ("T", self._m_2.dtype.as_datatype_enum) + inputs = [m] + + def f(): + pywrap_tensorflow.TFE_Py_Execute( + ctx_handle, None, "Identity", inputs, attrs, 1) + + self._run(f, 30000) + def benchmark_tf_gradient_function_identity(self): m = self._m_2 self._run( diff --git a/tensorflow/python/eager/core.py b/tensorflow/python/eager/core.py index 3f3d38b9510..483b7172107 100644 --- a/tensorflow/python/eager/core.py +++ b/tensorflow/python/eager/core.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow -from tensorflow.python.eager import memory_trace from tensorflow.python.framework import errors # Trace of execution and memory usage. @@ -48,28 +47,3 @@ class _NotOkStatusException(Exception): pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException) - - -def enable_tracing(): - """Enables tracing of execution and memory usage. - - WARNING: tracing is not thread-safe. - """ - # TODO(alive): Add code example in doc string. - global _active_trace - _active_trace = memory_trace.MemoryTrace() - - -def flush_trace(): - """Flushes the active trace, if it exists. - - WARNING: tracing is not thread-safe. - """ - # TODO(alive): Add code example in doc string. - if _active_trace is not None: - _active_trace.flush_trace() - - -def active_trace(): - """Returns the current global active trace of execution and memory usage.""" - return _active_trace diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py index c6457232e91..e392c6bb53b 100644 --- a/tensorflow/python/eager/execute.py +++ b/tensorflow/python/eager/execute.py @@ -65,15 +65,7 @@ def execute(op_name, num_outputs, inputs, attrs, ctx, name=None): message = e.message six.raise_from(core._status_to_exception(e.code, message), None) - # TODO(alive, cais): Use the execution callback mechanism. - if core.active_trace() is not None: - for t in tensors: - core.active_trace().record_tensor(op_name, - ops.tensor_id(t), - t.device, - t.shape.num_elements()) # pylint: enable=protected-access - # TODO(cais): Optimize this, perhaps by replacing this execute function with # a different one when there are execution callback(s). for callback in ctx.post_execution_callbacks: @@ -168,8 +160,11 @@ def make_tensor(v, arg_name): def args_to_matching_eager(l, ctx, default_dtype=None): """Convert sequence `l` to eager same-type Tensors.""" EagerTensor = ops.EagerTensor # pylint: disable=invalid-name - if all(isinstance(x, EagerTensor) for x in l): - return l[0].dtype, l + for x in l: + if not isinstance(x, EagerTensor): + break + else: # note: intentional for-else + return l[0]._datatype_enum(), l # pylint: disable=protected-access # TODO(josh11b): Could we do a better job if we also passed in the # allowed dtypes when that was known? @@ -193,7 +188,7 @@ def args_to_matching_eager(l, ctx, default_dtype=None): else: ret = [internal_convert_to_tensor(t, dtype, ctx=ctx) for t in l] - return dtype, ret + return dtype.as_datatype_enum, ret def convert_to_mixed_eager_tensors(values, ctx): @@ -202,7 +197,7 @@ def convert_to_mixed_eager_tensors(values, ctx): t, context=ctx._handle, device=ctx.device_name) # pylint: disable=protected-access for t in values ] - types = [t.dtype for t in v] + types = [t._datatype_enum() for t in v] # pylint: disable=protected-access return types, v @@ -240,5 +235,5 @@ def args_to_mixed_eager_tensors(lists, ctx): for j in range(len(lists)): lists_ret[j].append( ops.internal_convert_to_tensor(lists[j][i], dtype=dtype, ctx=ctx)) - types.append(dtype) + types.append(dtype.as_datatype_enum) return types, lists_ret diff --git a/tensorflow/python/eager/memory_trace.py b/tensorflow/python/eager/memory_trace.py deleted file mode 100644 index 094bcab9e2e..00000000000 --- a/tensorflow/python/eager/memory_trace.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility to trace per-device memory consumption across time over execution.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -TraceEntry = collections.namedtuple( - "TraceEntry", ["op_name", "tensor_id", "mem_usage", "device", "size"]) -TensorData = collections.namedtuple( - "TensorData", ["op_name", "tensor_size", "device"]) - - -class MemoryTrace(object): - """Records a trace of memory usage over operation execution.""" - - def __init__(self): - - self.trace = [] - self.tensor_to_data = {} - self.current_device_mem_usage = collections.defaultdict(int) - - def record_tensor(self, op_name, tensor_id, device, size): - self.current_device_mem_usage[device] += size - self.tensor_to_data[tensor_id] = TensorData(op_name, size, device) - self.trace.append(TraceEntry(op_name, - tensor_id, - dict(self.current_device_mem_usage.items()), - device, - size)) - - def delete_tensor(self, tensor_id): - if tensor_id not in self.tensor_to_data: - return - data = self.tensor_to_data.pop(tensor_id, None) - if data is None: return - self.current_device_mem_usage[data.device] -= data.tensor_size - self.trace.append(TraceEntry(data.op_name, - tensor_id, - dict(self.current_device_mem_usage.items()), - data.device, - -data.tensor_size)) - - def flush_trace(self): - """Prints the formatted trace recorded so far.""" - longest_op_name = max(len(t.op_name) for t in self.trace) - longest_op_name = max(longest_op_name, len("op_name")) - longest_heap_size = max(max(len(str(d)) for d in t.mem_usage) - for t in self.trace) - longest_heap_size = max(longest_heap_size, len("d0")) - longest_id_len = max(len(str(t.tensor_id)) for t in self.trace) - longest_id_len = max(longest_id_len, 2) - first_line = [] - first_line.append("+/-") - first_line.append("op_name".ljust(longest_op_name)) - first_line.append("id".ljust(longest_id_len)) - for i in range(len(self.current_device_mem_usage)): - first_line.append(("d"+str(i)).ljust(longest_heap_size)) - first_line.append("size") - print(" | ".join(first_line)) - for t in self.trace: - line = [] - if t.size > 0: - line.append("+ ") - else: - line.append("- ") - line.append(t.op_name.ljust(longest_op_name)) - line.append(str(t.tensor_id).ljust(longest_id_len)) - for d in t.mem_usage: - line.append(str(d).ljust(longest_heap_size)) - line.append(str(t.size)) - print(" | ".join(line)) - self.trace = [] - print() diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index 51550c9f514..70e23b93117 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import sparse_ops @@ -349,6 +350,9 @@ class OpsTest(test_util.TensorFlowTestCase): x = constant_op.constant(3.1415) self.assertEqual('3.14', '{:.2f}'.format(x)) + def testNoOpIsNone(self): + self.assertTrue(control_flow_ops.no_op() is None) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc index e57488cb640..956fbdac50d 100644 --- a/tensorflow/python/eager/python_eager_op_gen.cc +++ b/tensorflow/python/eager/python_eager_op_gen.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb_text.h" @@ -100,8 +101,9 @@ string TensorPBString(const TensorProto& pb) { class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { public: - GenEagerPythonOp(const OpDef& op_def, const string& function_name) - : python_op_gen_internal::GenPythonOp(op_def, function_name) { + GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def, + const string& function_name) + : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) { op_name_ = function_name_; op_name_.Consume("_"); } @@ -139,8 +141,9 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { std::unordered_map attr_expressions_; }; -string GetEagerPythonOp(const OpDef& op_def, const string& function_name) { - return GenEagerPythonOp(op_def, function_name).Code(); +string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def, + const string& function_name) { + return GenEagerPythonOp(op_def, api_def, function_name).Code(); } string GenEagerPythonOp::FlattenInputs( @@ -528,6 +531,8 @@ string GenEagerPythonOp::Code() { strings::StrAppend(&result_, " _result = _", op_def_.name(), "Output._make(_result)\n"); } + } else { + strings::StrAppend(&result_, " _result = None\n"); } strings::StrAppend(&result_, " return _result\n\n"); return prelude_ + result_; @@ -589,8 +594,6 @@ void GenEagerPythonOp::AddEagerInferredAttrs() { strings::StrAppend(&result_, " ", VectorToTuple(p), " = ", inputs_var, "\n"); } - strings::StrAppend(&result_, " ", var_name, " = ", var_name, - ".as_datatype_enum\n"); } else if (attr.type() == "list(type)") { // NOTE: We ignore default values for these attrs, since it is // unclear how you would use it, and the one use case is @@ -617,9 +620,6 @@ void GenEagerPythonOp::AddEagerInferredAttrs() { } strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, " = ", conversion, "(", inputs_var, ", _ctx)\n"); - strings::StrAppend(&result_, " ", var_name, - " = [_t.as_datatype_enum for _t in ", var_name, - "]\n"); } } } @@ -667,7 +667,7 @@ void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) { WordWrap(return_prefix, return_args, kRightMargin), "\n"); } -string GetEagerPythonOps(const OpList& ops, +string GetEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs, const std::vector& hidden_ops, bool require_shapes, const string& source_file_name = "") { @@ -703,6 +703,7 @@ from tensorflow.python.framework import common_shapes as _common_shapes from tensorflow.python.framework import op_def_registry as _op_def_registry from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import op_def_library as _op_def_library +from tensorflow.python.util.tf_export import tf_export )"); @@ -732,7 +733,9 @@ from tensorflow.python.framework import op_def_library as _op_def_library continue; } - strings::StrAppend(&result, GetEagerPythonOp(op_def, function_name)); + const auto* api_def = api_defs.GetApiDef(op_def.name()); + strings::StrAppend(&result, + GetEagerPythonOp(op_def, *api_def, function_name)); if (!require_shapes) { strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(), @@ -765,19 +768,21 @@ from tensorflow.python.framework import op_def_library as _op_def_library } // namespace -void PrintEagerPythonOps(const OpList& ops, +void PrintEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs, const std::vector& hidden_ops, bool require_shapes, const string& source_file_name) { - printf("%s", - GetEagerPythonOps(ops, hidden_ops, require_shapes, source_file_name) - .c_str()); + printf("%s", GetEagerPythonOps(ops, api_defs, hidden_ops, require_shapes, + source_file_name) + .c_str()); } string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) { string op_list_str(op_list_buf, op_list_len); OpList ops; ops.ParseFromString(op_list_str); - return GetEagerPythonOps(ops, {}, false); + + ApiDefMap api_def_map(ops); + return GetEagerPythonOps(ops, api_def_map, {}, false); } } // namespace tensorflow diff --git a/tensorflow/python/eager/python_eager_op_gen.h b/tensorflow/python/eager/python_eager_op_gen.h index 250623850f2..f9dfdf0408f 100644 --- a/tensorflow/python/eager/python_eager_op_gen.h +++ b/tensorflow/python/eager/python_eager_op_gen.h @@ -18,6 +18,7 @@ limitations under the License. #include #include #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -26,7 +27,7 @@ namespace tensorflow { // in the output. Prints the output to stdout. // Optional fourth argument is the name of the original C++ source file // where the ops' REGISTER_OP() calls reside. -void PrintEagerPythonOps(const OpList& ops, +void PrintEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs, const std::vector& hidden_ops, bool require_shapes, const string& source_file_name = ""); diff --git a/tensorflow/python/eager/python_eager_op_gen_main.cc b/tensorflow/python/eager/python_eager_op_gen_main.cc index 9e4aa97ccc7..cd74c438ec6 100644 --- a/tensorflow/python/eager/python_eager_op_gen_main.cc +++ b/tensorflow/python/eager/python_eager_op_gen_main.cc @@ -20,15 +20,36 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" namespace tensorflow { namespace { +constexpr char kBaseApiDef[] = + "tensorflow/core/api_def/base_api/*.pbtxt"; +constexpr char kPythonApiDef[] = + "tensorflow/core/api_def/python_api/*.pbtxt"; +constexpr bool kUseApiDef = false; + void PrintAllPythonOps(const std::vector& hidden_ops) { OpList ops; OpRegistry::Global()->Export(false, &ops); - PrintEagerPythonOps(ops, hidden_ops, true /* require_shapes */); + + ApiDefMap api_def_map(ops); + if (kUseApiDef) { + Env* env = Env::Default(); + + std::vector base_api_files; + std::vector python_api_files; + TF_CHECK_OK(env->GetMatchingPaths(kBaseApiDef, &base_api_files)); + TF_CHECK_OK(env->GetMatchingPaths(kPythonApiDef, &python_api_files)); + + TF_CHECK_OK(api_def_map.LoadFileList(env, base_api_files)); + TF_CHECK_OK(api_def_map.LoadFileList(env, python_api_files)); + } + PrintEagerPythonOps(ops, api_def_map, hidden_ops, true /* require_shapes */); } } // namespace diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 653f3ef84e3..91192fea62d 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -330,24 +330,9 @@ void EagerTensor_dealloc(EagerTensor* self) { // We have the global interpreter lock, so use this chance to perform delayed // refcount decrements. tensorflow::ClearDecrefCache(); - PyObject* id = PyLong_FromLongLong(self->id); - PyObject* func = PyObject_GetAttrString(reinterpret_cast(self), - "_delete_trace"); + auto id = self->id; Py_TYPE(self)->tp_free(self); - self = nullptr; - // Note that we run `func` after calling `tp_free`. Otherwise calling that - // function can potentially trigger garbage collection that observes `self` - // in this half deleted state and crashes. - // Note that `func` is a staticmethod and does not need `self` to be around - // for running. - // We clear (and later restore) any errors that have already been set. Else - // these erorrs may appear randomly as part of the function execution. - PyObject *a, *b, *c; - PyErr_Fetch(&a, &b, &c); - PyObject_CallFunctionObjArgs(func, id, nullptr); - PyErr_Restore(a, b, c); - Py_DECREF(func); - Py_DECREF(id); + TFE_Py_TapeStackDeleteTrace(id); } // Getter for `_id`. diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index a67519f9a22..f96245f7a53 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -87,22 +87,36 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o); // newly created type, or nullptr on error. PyObject* TFE_Py_InitEagerTensor(PyObject* base_class); -PyObject* TFE_Py_NewTape(); -PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors); -void TFE_Py_TapeWatch(PyObject* tape, tensorflow::int64 tensor_id); -void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id); +// Pushes a new tape into the thread-local stack. +void TFE_Py_TapeStackPushNew(); -// Records an operation in the gradient tape. `tape` should point to an object -// returned by TFE_Py_NewTape. op_type is a string for the operation type, used -// in the backprop code. output_tensors should be a list of python ops.Tensor -// objects. input_tensor_ids should be a list of python integers with the ids of -// the input tensors of the recorded operation. backward_function should be the -// function to be called during backprop to, given the gradients of the output -// tensors, produce the gradients of the input tensors. -void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type, - PyObject* output_tensors, - PyObject* input_tensor_ids, - PyObject* backward_function); +// Pops the tape from the top of the stack and returns it. +PyObject* TFE_Py_TapeStackPop(); + +// Pushes an existing tape onto the stack. +void TFE_Py_TapeStackPush(PyObject* tape); + +// Returns true if the tape stack is empty. +PyObject* TFE_Py_TapeStackIsEmpty(); + +PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors); +void TFE_Py_TapeStackWatch(PyObject* tensor); +void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id); + +// Records an operation in the gradient tape stack.type is a string for the +// operation type, used in the backprop code. output_tensors should be a list of +// python ops.Tensor objects. input_tensor_ids should be a list of python +// integers with the ids of the input tensors of the recorded +// operation. backward_function should be the function to be called during +// backprop to, given the gradients of the output tensors, produce the gradients +// of the input tensors. +void TFE_Py_TapeStackRecordOperation(PyObject* op_type, + PyObject* output_tensors, + PyObject* input_tensor_ids, + PyObject* backward_function); + +// Watches the given variable object on the given tape. +void TFE_Py_TapeStackWatchVariable(PyObject* variable); // Computes a gradient based on information recorded on the tape.`tape` must // have been produced by TFE_Py_NewTape. `vspace` must be a @@ -114,9 +128,6 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, PyObject* target, PyObject* sources, PyObject* output_gradients, TF_Status* status); -// Watches the given variable object on the given tape. -void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable); - // Returns the set of variables watched by the given tape. PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 5cb1313c4b0..387eec13584 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/python/eager/pywrap_tfe.h" #include "tensorflow/c/c_api.h" @@ -525,12 +527,65 @@ static PyTypeObject TFE_Py_Tape_Type = { "TFE_Py_Tape objects", /* tp_doc */ }; -PyObject* TFE_Py_NewTape() { +// xcode 7 doesn't define thread_local, so for compatibility we implement our +// own. TODO(apassos) remove once we can deprecate xcode 7. +#ifndef __APPLE__ +thread_local std::vector* tape_stack = nullptr; +std::vector* GetTapeStack() { + if (tape_stack == nullptr) { + tape_stack = new std::vector; + } + return tape_stack; +} +#else +static tensorflow::mutex stack_mu(tensorflow::LINKER_INITIALIZED); +static std::unordered_map*>* + tape_stack GUARDED_BY(stack_mu) = nullptr; +std::vector* GetTapeStack() { + tensorflow::mutex_lock ml(stack_mu); + if (tape_stack == nullptr) { + tape_stack = + new std::unordered_map*>; + } + auto it = tape_stack->find(std::this_thread::get_id()); + if (it != tape_stack->end()) { + return it->second; + } + return tape_stack + ->emplace(std::this_thread::get_id(), new std::vector) + .first->second; +} +#endif + +void TFE_Py_TapeStackPushNew() { TFE_Py_Tape_Type.tp_new = PyType_GenericNew; - if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr; + if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return; TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type); tape->tape = new GradientTape(); - return reinterpret_cast(tape); + GetTapeStack()->push_back(tape); +} + +void TFE_Py_TapeStackPush(PyObject* tape) { + Py_INCREF(tape); + GetTapeStack()->push_back(reinterpret_cast(tape)); +} + +PyObject* TFE_Py_TapeStackIsEmpty() { + if (GetTapeStack()->empty()) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + +PyObject* TFE_Py_TapeStackPop() { + auto* stack = GetTapeStack(); + if (stack->empty()) { + PyErr_SetString(PyExc_RuntimeError, "tape stack is empty."); + return nullptr; + } + TFE_Py_Tape* top = stack->back(); + stack->pop_back(); + return reinterpret_cast(top); } static std::vector MakeIntList(PyObject* list) { @@ -557,10 +612,14 @@ static std::vector MakeIntList(PyObject* list) { return tensor_ids; } -PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors) { +PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) { if (tensors == Py_None) { Py_RETURN_FALSE; } + auto* stack = GetTapeStack(); + if (stack->empty()) { + Py_RETURN_FALSE; + } PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); if (seq == nullptr) { return nullptr; @@ -575,16 +634,22 @@ PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors) { tensor_ids.push_back(FastTensorId(item)); } Py_DECREF(seq); - TFE_Py_Tape* tape = reinterpret_cast(py_tape); - if (tape->tape->ShouldRecord(tensor_ids)) { - Py_RETURN_TRUE; - } else { - Py_RETURN_FALSE; + for (TFE_Py_Tape* tape : *stack) { + if (tape->tape->ShouldRecord(tensor_ids)) { + Py_RETURN_TRUE; + } } + Py_RETURN_FALSE; } -void TFE_Py_TapeWatch(PyObject* tape, tensorflow::int64 tensor_id) { - reinterpret_cast(tape)->tape->Watch(tensor_id); +void TFE_Py_TapeStackWatch(PyObject* tensor) { + tensorflow::int64 tensor_id = FastTensorId(tensor); + if (PyErr_Occurred()) { + return; + } + for (TFE_Py_Tape* tape : *GetTapeStack()) { + tape->tape->Watch(tensor_id); + } } static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { @@ -646,8 +711,10 @@ std::vector MakeTensorIDList(PyObject* tensors) { return list; } -void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) { - reinterpret_cast(tape)->tape->WatchVariable(variable); +void TFE_Py_TapeStackWatchVariable(PyObject* variable) { + for (TFE_Py_Tape* tape : *GetTapeStack()) { + tape->tape->WatchVariable(variable); + } } PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { @@ -661,10 +728,14 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { return result; } -void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type, - PyObject* output_tensors, - PyObject* input_tensors, - PyObject* backward_function) { +void TFE_Py_TapeStackRecordOperation(PyObject* op_type, + PyObject* output_tensors, + PyObject* input_tensors, + PyObject* backward_function) { + auto* stack = GetTapeStack(); + if (stack->empty()) { + return; + } std::vector input_ids = MakeTensorIDList(input_tensors); std::vector output_info; PyObject* seq = PySequence_Fast(output_tensors, @@ -697,14 +768,18 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type, return; } - Py_INCREF(backward_function); - reinterpret_cast(tape)->tape->RecordOperation( - op_type_str, output_info, input_ids, backward_function, - [backward_function]() { Py_DECREF(backward_function); }); + for (TFE_Py_Tape* tape : *stack) { + Py_INCREF(backward_function); + tape->tape->RecordOperation( + op_type_str, output_info, input_ids, backward_function, + [backward_function]() { Py_DECREF(backward_function); }); + } } -void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id) { - reinterpret_cast(tape)->tape->DeleteTrace(tensor_id); +void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id) { + for (TFE_Py_Tape* tape : *GetTapeStack()) { + tape->tape->DeleteTrace(tensor_id); + } } class PyVSpace : public tensorflow::eager::VSpace { diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index fb6b62a3e09..440c84b7ea9 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -18,106 +18,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import contextlib -import threading from tensorflow.python import pywrap_tensorflow -def tid(tensor): - return tensor._id # pylint: disable=protected-access - - -class TapeEntry( - collections.namedtuple("TapeEntry", [ - "op_type", - "output_ids", "input_ids", "backward_function", - "output_shape_and_dtype", - ])): - """Entry in the gradient tape. - - Represents the execution of one op or function, with instructions for doing - its backward pass and useful information for it. - - Args: - output_ids: tensor_id(t) for each output tensor T - input_ids: tensor_id(t) for each input tensor T - backward_function: function to be called with the downstream gradients and - side outputs as arguments which computes the backward pass. - output_shape_and_dtype: a list of (shape_tuple, dtype) for every output - tensor_id - """ - - -def _tensor_shape(t): - return t._shape_tuple() # pylint: disable=protected-access - - class Tape(object): """Represents a gradient propagation trace.""" - def __init__(self): - self._tape = pywrap_tensorflow.TFE_Py_NewTape() - - def should_record(self, tensors): - """Returns true if any tensor should be recorded. - - Args: - tensors: some tensors. - - Returns: - True if any of the tensors is in the tape. - """ - return pywrap_tensorflow.TFE_Py_TapeShouldRecord( - self._tape, tensors) - - def watch(self, tensor): - """Adds a tensor to the tape.""" - pywrap_tensorflow.TFE_Py_TapeWatch(self._tape, tid(tensor)) - - def watch_variable(self, v): - pywrap_tensorflow.TFE_Py_TapeWatchVariable(self._tape, v) + def __init__(self, tape): + self._tape = tape def watched_variables(self): return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape) - def record_operation(self, op_type, output_tensors, input_tensors, - backward_function): - """Records an operation in the tape.""" - pywrap_tensorflow.TFE_Py_TapeRecordOperation( - self._tape, - op_type, - output_tensors, - input_tensors, - backward_function) - - def _delete_tensor_id(self, i): - pywrap_tensorflow.TFE_Py_TapeDeleteTrace(self._tape, i) - - def delete_trace(self, tensor_id): - """Deletes any trace we have for this tensor.""" - self._delete_tensor_id(tensor_id) - - -class _TapeStack(threading.local): - - def __init__(self): - super(_TapeStack, self).__init__() - self._stack = [] - - @property - def stack(self): - return self._stack - - -# The global tape stack. -_tape_stack = _TapeStack() - def push_new_tape(): """Pushes a new tape onto the tape stack.""" - _tape_stack.stack.append(Tape()) + pywrap_tensorflow.TFE_Py_TapeStackPushNew() def watch(tensor): @@ -126,8 +44,7 @@ def watch(tensor): Args: tensor: tensor to be watched. """ - for t in _tape_stack.stack: - t.watch(tensor) + pywrap_tensorflow.TFE_Py_TapeStackWatch(tensor) def watch_variable(variable): @@ -136,48 +53,42 @@ def watch_variable(variable): Args: variable: variable to be watched. """ - for t in _tape_stack.stack: - t.watch_variable(variable) + pywrap_tensorflow.TFE_Py_TapeStackWatchVariable(variable) def pop_tape(): """Pops the top tape in the stack, if any.""" - if _tape_stack.stack: - return _tape_stack.stack.pop() - return None + return Tape(pywrap_tensorflow.TFE_Py_TapeStackPop()) @contextlib.contextmanager def stop_recording(): - old = _tape_stack.stack - _tape_stack._stack = [] # pylint: disable=protected-access + stack = [] + while not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty(): + stack.append(pop_tape()._tape) # pylint: disable=protected-access try: yield finally: - _tape_stack._stack = old # pylint: disable=protected-access + for tape in reversed(stack): + pywrap_tensorflow.TFE_Py_TapeStackPush(tape) def should_record(tensors): """Returns true if any tape in the stack watches any of these tensors.""" - if not _tape_stack.stack: - return False - return any(x.should_record(tensors) for x in _tape_stack.stack) + return pywrap_tensorflow.TFE_Py_TapeStackShouldRecord(tensors) def record_operation(op_type, output_tensors, input_tensors, backward_function): """Records the operation on all tapes in the stack.""" - for t in _tape_stack.stack: - t.record_operation(op_type, output_tensors, - input_tensors, - backward_function) + pywrap_tensorflow.TFE_Py_TapeStackRecordOperation( + op_type, output_tensors, input_tensors, backward_function) def delete_trace(tensor_id): """Deletes traces for this Tensor from all tapes in the stack.""" - for t in _tape_stack.stack: - t.delete_trace(tensor_id) + pywrap_tensorflow.TFE_Py_TapeStackDeleteTrace(tensor_id) def could_possibly_record(): """Returns True if any tape is active.""" - return len(_tape_stack.stack) > 0 # pylint: disable=g-explicit-length-test + return not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty() diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index d13ecd13a1b..fa5d02c4767 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -1081,7 +1081,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): if mode == model_fn.ModeKeys.EVAL: # Estimator already adds a metric for loss. eval_metric_ops = { - metric_keys.MetricKeys.LOSS_MEAN: + _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN): metrics_lib.mean( # Both values and weights here are reduced, scalar Tensors. # values is the actual mean we want -- weights represents diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 4497cd26f2d..f3afd84125d 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -2325,6 +2325,24 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): self.assertAllClose(expected_loss_mean, loss_mean) self.assertAllClose(expected_loss_mean, loss_mean_value_op.eval()) + def test_eval_metric_ops_with_head_name_for_regression(self): + head = head_lib._regression_head_with_mean_squared_error_loss( + name='some_regression_head') + logits = np.array(((1,), (9,)), dtype=np.float32) + labels = np.array(((1,), (1,)), dtype=np.int64) + features = {'x': np.array(((42,),), dtype=np.int32)} + # Create estimator spec. + spec = head.create_estimator_spec( + features=features, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=labels) + + expected_metric_keys = [ + '{}/some_regression_head'.format(metric_keys.MetricKeys.LOSS_MEAN), + ] + self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys()) + def test_train_create_loss(self): head = head_lib._regression_head_with_mean_squared_error_loss() logits = np.array(((45,), (41,),), dtype=np.float32) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 2d036e2cfba..f267f4a54e5 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -461,8 +461,12 @@ class Estimator(object): assets_extra=None, as_text=False, checkpoint_path=None): + # pylint: disable=line-too-long """Exports inference graph as a SavedModel into given dir. + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + This method builds a new graph by first calling the serving_input_receiver_fn to obtain feature `Tensor`s, and then calling this `Estimator`'s model_fn to generate the model graph based on those @@ -506,6 +510,7 @@ class Estimator(object): ValueError: if no serving_input_receiver_fn is provided, no export_outputs are provided, or no checkpoint can be found. """ + # pylint: enable=line-too-long if serving_input_receiver_fn is None: raise ValueError('serving_input_receiver_fn must be defined.') diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 190a25d4d79..5ff75162468 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -233,6 +233,8 @@ def input_layer(features, ordered_columns = [] for column in sorted(feature_columns, key=lambda x: x.name): ordered_columns.append(column) + # TODO(b/67952670): Implement a column._var_scope_name property and use + # that instead of column.name. with variable_scope.variable_scope(None, default_name=column.name): tensor = column._get_dense_tensor( # pylint: disable=protected-access builder, @@ -340,6 +342,8 @@ def linear_model(features, ordered_columns = [] builder = _LazyBuilder(features) for column in sorted(feature_columns, key=lambda x: x.name): + # TODO(b/67952670): Implement a column._var_scope_name property and use + # that instead of column.name. with variable_scope.variable_scope(None, default_name=column.name): ordered_columns.append(column) if isinstance(column, _CategoricalColumn): @@ -489,15 +493,36 @@ def embedding_column( representation (e.g., to feed to a DNN). Inputs must be a `_CategoricalColumn` created by any of the - `categorical_column_*` function. Here is an example embedding of an identity - column for a DNN model: + `categorical_column_*` function. Here is an example of using + `embedding_column` with `DNNClassifier`: ```python video_id = categorical_column_with_identity( key='video_id', num_buckets=1000000, default_value=0) columns = [embedding_column(video_id, 9),...] - features = tf.parse_example(..., features=make_parse_example_spec(columns)) - dense_tensor = input_layer(features, columns) + + estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...) + + label_column = ... + def input_fn(): + features = tf.parse_example( + ..., features=make_parse_example_spec(columns + [label_column])) + labels = features.pop(label_column.name) + return features, labels + + estimator.train(input_fn=input_fn, steps=100) + ``` + + Here is an example using `embedding_column` with model_fn: + + ```python + def model_fn(features, ...): + video_id = categorical_column_with_identity( + key='video_id', num_buckets=1000000, default_value=0) + columns = [embedding_column(video_id, 9),...] + dense_tensor = input_layer(features, columns) + # Form DNN layers, calculate loss, and return EstimatorSpec. + ... ``` Args: @@ -551,12 +576,144 @@ def embedding_column( dimension=dimension, combiner=combiner, initializer=initializer, + shared_embedding_collection_name=None, ckpt_to_load_from=ckpt_to_load_from, tensor_name_in_ckpt=tensor_name_in_ckpt, max_norm=max_norm, trainable=trainable) +def _shared_embedding_columns( + categorical_columns, dimension, combiner='mean', initializer=None, + shared_embedding_collection_name=None, ckpt_to_load_from=None, + tensor_name_in_ckpt=None, max_norm=None, trainable=True): + """List of `_DenseColumn`s that convert from sparse, categorical input. + + This is similar to `embedding_column`, except that that it produces a list of + embedding columns that share the same embedding weights. + + Use this when your inputs are sparse and of the same type (e.g. watched and + impression video IDs that share the same vocabulary), and you want to convert + them to a dense representation (e.g., to feed to a DNN). + + Inputs must be a list of `_CategoricalColumn` created by any of the + `categorical_column_*` function. They must all be of the same type and have + the same arguments except `key`. E.g. they can be + categorical_column_with_vocabulary_file with the same vocabulary_file. Some or + all columns could also be weighted_categorical_column. + + Here is an example embedding of two features for a DNNClassifier model: + + ```python + watched_video_id = categorical_column_with_vocabulary_file( + 'watched_video_id', video_vocabulary_file, video_vocabulary_size) + impression_video_id = categorical_column_with_vocabulary_file( + 'impression_video_id', video_vocabulary_file, video_vocabulary_size) + columns = shared_embedding_columns( + [watched_video_id, impression_video_id], dimension=10) + + estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...) + + label_column = ... + def input_fn(): + features = tf.parse_example( + ..., features=make_parse_example_spec(columns + [label_column])) + labels = features.pop(label_column.name) + return features, labels + + estimator.train(input_fn=input_fn, steps=100) + ``` + + Here is an example using `shared_embedding_columns` with model_fn: + + ```python + def model_fn(features, ...): + watched_video_id = categorical_column_with_vocabulary_file( + 'watched_video_id', video_vocabulary_file, video_vocabulary_size) + impression_video_id = categorical_column_with_vocabulary_file( + 'impression_video_id', video_vocabulary_file, video_vocabulary_size) + columns = shared_embedding_columns( + [watched_video_id, impression_video_id], dimension=10) + dense_tensor = input_layer(features, columns) + # Form DNN layers, calculate loss, and return EstimatorSpec. + ... + ``` + + Args: + categorical_columns: List of `_CategoricalColumn`s created by a + `categorical_column_with_*` function. These columns produce the sparse IDs + that are inputs to the embedding lookup. All columns must be of the same + type and have the same arguments except `key`. E.g. they can be + categorical_column_with_vocabulary_file with the same vocabulary_file. + Some or all columns could also be weighted_categorical_column. + dimension: An integer specifying dimension of the embedding, must be > 0. + combiner: A string specifying how to reduce if there are multiple entries + in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with + 'mean' the default. 'sqrtn' often achieves good accuracy, in particular + with bag-of-words columns. Each of this can be thought as example level + normalizations on the column. For more information, see + `tf.embedding_lookup_sparse`. + initializer: A variable initializer function to be used in embedding + variable initialization. If not specified, defaults to + `tf.truncated_normal_initializer` with mean `0.0` and standard deviation + `1/sqrt(dimension)`. + shared_embedding_collection_name: Optional name of the collection where + shared embedding weights are added. If not given, a reasonable name will + be chosen based on the names of `categorical_columns`. + ckpt_to_load_from: String representing checkpoint name/pattern from which to + restore column weights. Required if `tensor_name_in_ckpt` is not `None`. + tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from + which to restore the column weights. Required if `ckpt_to_load_from` is + not `None`. + max_norm: If not `None`, embedding values are l2-normalized to this value. + trainable: Whether or not the embedding is trainable. Default is True. + + Returns: + A list of `_DenseColumn`s that converts from sparse input. The order of + results follows the ordering of `categorical_columns`. + + Raises: + ValueError: if `dimension` not > 0. + ValueError: if any of the given `categorical_columns` is of different type + or has different arguments than the others. + ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt` + is specified. + ValueError: if `initializer` is specified and is not callable. + """ + if (dimension is None) or (dimension < 1): + raise ValueError('Invalid dimension {}.'.format(dimension)) + if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): + raise ValueError('Must specify both `ckpt_to_load_from` and ' + '`tensor_name_in_ckpt` or none of them.') + + if (initializer is not None) and (not callable(initializer)): + raise ValueError('initializer must be callable if specified.') + if initializer is None: + initializer = init_ops.truncated_normal_initializer( + mean=0.0, stddev=1 / math.sqrt(dimension)) + # TODO(b/67952670): Validate categorical_columns. + if not shared_embedding_collection_name: + # Sort the columns so the name is deterministic even if the user passes + # columns from an unsorted collection, such as dict.values(). + sorted_columns = sorted(categorical_columns, key=lambda x: x.name) + shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns) + shared_embedding_collection_name += '_shared_embedding' + + result = [] + for column in categorical_columns: + result.append(_EmbeddingColumn( + categorical_column=column, + dimension=dimension, + combiner=combiner, + initializer=initializer, + shared_embedding_collection_name=shared_embedding_collection_name, + ckpt_to_load_from=ckpt_to_load_from, + tensor_name_in_ckpt=tensor_name_in_ckpt, + max_norm=max_norm, + trainable=trainable)) + return result + + def numeric_column(key, shape=(1,), default_value=None, @@ -1847,14 +2004,18 @@ class _EmbeddingColumn( _DenseColumn, collections.namedtuple('_EmbeddingColumn', ( 'categorical_column', 'dimension', 'combiner', 'initializer', - 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable' + 'shared_embedding_collection_name', 'ckpt_to_load_from', + 'tensor_name_in_ckpt', 'max_norm', 'trainable' ))): - """See `_embedding_column`.""" + """See `embedding_column`.""" @property def name(self): if not hasattr(self, '_name'): - self._name = '{}_embedding'.format(self.categorical_column.name) + if self.shared_embedding_collection_name: + self._name = '{}_shared_embedding'.format(self.categorical_column.name) + else: + self._name = '{}_embedding'.format(self.categorical_column.name) return self._name @property @@ -1877,14 +2038,47 @@ class _EmbeddingColumn( sparse_ids = sparse_tensors.id_tensor sparse_weights = sparse_tensors.weight_tensor - # Create embedding weight, and restore from checkpoint if necessary. - embedding_weights = variable_scope.get_variable( - name='embedding_weights', - shape=(self.categorical_column._num_buckets, self.dimension), # pylint: disable=protected-access - dtype=dtypes.float32, - initializer=self.initializer, - trainable=self.trainable and trainable, - collections=weight_collections) + embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access + if self.shared_embedding_collection_name: + shared_embedding_collection = ops.get_collection( + self.shared_embedding_collection_name) + if shared_embedding_collection: + if len(shared_embedding_collection) > 1: + raise ValueError( + 'Collection {} can only contain one variable. ' + 'Suggested fix A: Choose a unique name for this collection. ' + 'Suggested fix B: Do not add any variables to this collection. ' + 'The feature_column library already adds a variable under the ' + 'hood.'.format(shared_embedding_collection)) + embedding_weights = shared_embedding_collection[0] + if embedding_weights.shape != embedding_shape: + raise ValueError( + 'Shared embedding collection {} contains variable {} of ' + 'unexpected shape {}. Expected shape is {}. ' + 'Suggested fix A: Choose a unique name for this collection. ' + 'Suggested fix B: Do not add any variables to this collection. ' + 'The feature_column library already adds a variable under the ' + 'hood.'.format( + self.shared_embedding_collection_name, embedding_weights.name, + embedding_weights.shape, embedding_shape)) + else: + embedding_weights = variable_scope.get_variable( + name=self.shared_embedding_collection_name + '_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable and trainable, + collections=weight_collections) + ops.add_to_collection( + self.shared_embedding_collection_name, embedding_weights) + else: + embedding_weights = variable_scope.get_variable( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable and trainable, + collections=weight_collections) if self.ckpt_to_load_from is not None: to_restore = embedding_weights if isinstance(to_restore, variables.PartitionedVariable): diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index e57e9a9836c..4b06a85ad34 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -27,6 +27,7 @@ from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.client import session from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column as fc_lib from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.feature_column.feature_column import _CategoricalColumn from tensorflow.python.feature_column.feature_column import _DenseColumn @@ -3403,6 +3404,7 @@ class EmbeddingColumnTest(test.TestCase): self.assertEqual('mean', embedding_column.combiner) self.assertIsNotNone(embedding_column.initializer) self.assertIsNone(embedding_column.ckpt_to_load_from) + self.assertIsNone(embedding_column.shared_embedding_collection_name) self.assertIsNone(embedding_column.tensor_name_in_ckpt) self.assertIsNone(embedding_column.max_norm) self.assertTrue(embedding_column.trainable) @@ -3426,6 +3428,7 @@ class EmbeddingColumnTest(test.TestCase): self.assertEqual(embedding_dimension, embedding_column.dimension) self.assertEqual('my_combiner', embedding_column.combiner) self.assertEqual('my_initializer', embedding_column.initializer()) + self.assertIsNone(embedding_column.shared_embedding_collection_name) self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from) self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt) self.assertEqual(42., embedding_column.max_norm) @@ -3456,6 +3459,7 @@ class EmbeddingColumnTest(test.TestCase): self.assertEqual(embedding_dimension, embedding_column.dimension) self.assertEqual('my_combiner', embedding_column.combiner) self.assertEqual('my_initializer', embedding_column.initializer()) + self.assertIsNone(embedding_column.shared_embedding_collection_name) self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from) self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt) self.assertEqual(42., embedding_column.max_norm) @@ -3979,6 +3983,269 @@ class EmbeddingColumnTest(test.TestCase): self.assertAllEqual(expected_lookups, input_layer.eval()) +class SharedEmbeddingColumnTest(test.TestCase): + + def test_defaults(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + embedding_column_b, embedding_column_a = fc_lib._shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension) + self.assertIs(categorical_column_a, embedding_column_a.categorical_column) + self.assertIs(categorical_column_b, embedding_column_b.categorical_column) + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual(embedding_dimension, embedding_column_b.dimension) + self.assertEqual('mean', embedding_column_a.combiner) + self.assertEqual('mean', embedding_column_b.combiner) + self.assertIsNotNone(embedding_column_a.initializer) + self.assertIsNotNone(embedding_column_b.initializer) + self.assertIsNone(embedding_column_a.ckpt_to_load_from) + self.assertIsNone(embedding_column_b.ckpt_to_load_from) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_a.shared_embedding_collection_name) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_b.shared_embedding_collection_name) + self.assertIsNone(embedding_column_a.tensor_name_in_ckpt) + self.assertIsNone(embedding_column_b.tensor_name_in_ckpt) + self.assertIsNone(embedding_column_a.max_norm) + self.assertIsNone(embedding_column_b.max_norm) + self.assertTrue(embedding_column_a.trainable) + self.assertTrue(embedding_column_b.trainable) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) + self.assertEqual('bbb_shared_embedding', embedding_column_b.name) + self.assertEqual( + (embedding_dimension,), embedding_column_a._variable_shape) + self.assertEqual( + (embedding_dimension,), embedding_column_b._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a._parse_example_spec) + self.assertEqual({ + 'bbb': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_b._parse_example_spec) + + def test_all_constructor_args(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + combiner='my_combiner', + initializer=lambda: 'my_initializer', + shared_embedding_collection_name='shared_embedding_collection_name', + ckpt_to_load_from='my_ckpt', + tensor_name_in_ckpt='my_ckpt_tensor', + max_norm=42., + trainable=False) + self.assertIs(categorical_column_a, embedding_column_a.categorical_column) + self.assertIs(categorical_column_b, embedding_column_b.categorical_column) + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual(embedding_dimension, embedding_column_b.dimension) + self.assertEqual('my_combiner', embedding_column_a.combiner) + self.assertEqual('my_combiner', embedding_column_b.combiner) + self.assertEqual('my_initializer', embedding_column_a.initializer()) + self.assertEqual('my_initializer', embedding_column_b.initializer()) + self.assertEqual('shared_embedding_collection_name', + embedding_column_a.shared_embedding_collection_name) + self.assertEqual('shared_embedding_collection_name', + embedding_column_b.shared_embedding_collection_name) + self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) + self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from) + self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) + self.assertEqual('my_ckpt_tensor', embedding_column_b.tensor_name_in_ckpt) + self.assertEqual(42., embedding_column_a.max_norm) + self.assertEqual(42., embedding_column_b.max_norm) + self.assertFalse(embedding_column_a.trainable) + self.assertFalse(embedding_column_b.trainable) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) + self.assertEqual('bbb_shared_embedding', embedding_column_b.name) + self.assertEqual( + (embedding_dimension,), embedding_column_a._variable_shape) + self.assertEqual( + (embedding_dimension,), embedding_column_b._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a._parse_example_spec) + self.assertEqual({ + 'bbb': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_b._parse_example_spec) + + def test_deep_copy(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + original_a, _ = fc_lib._shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + combiner='my_combiner', + initializer=lambda: 'my_initializer', + shared_embedding_collection_name='shared_embedding_collection_name', + ckpt_to_load_from='my_ckpt', + tensor_name_in_ckpt='my_ckpt_tensor', + max_norm=42., trainable=False) + for embedding_column_a in (original_a, copy.deepcopy(original_a)): + self.assertEqual('aaa', embedding_column_a.categorical_column.name) + self.assertEqual(3, embedding_column_a.categorical_column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a.categorical_column._parse_example_spec) + + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual('my_combiner', embedding_column_a.combiner) + self.assertEqual('my_initializer', embedding_column_a.initializer()) + self.assertEqual('shared_embedding_collection_name', + embedding_column_a.shared_embedding_collection_name) + self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) + self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) + self.assertEqual(42., embedding_column_a.max_norm) + self.assertFalse(embedding_column_a.trainable) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) + self.assertEqual( + (embedding_dimension,), embedding_column_a._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a._parse_example_spec) + + def test_invalid_initializer(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + with self.assertRaisesRegexp(ValueError, 'initializer must be callable'): + fc_lib._shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2, + initializer='not_fn') + + def test_parse_example(self): + a = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) + b = fc.categorical_column_with_vocabulary_list( + key='bbb', vocabulary_list=('omar', 'stringer', 'marlo')) + a_embedded, b_embedded = fc_lib._shared_embedding_columns( + [a, b], dimension=2) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'aaa': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'omar', b'stringer'])), + 'bbb': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'stringer', b'marlo'])), + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([a_embedded, b_embedded])) + self.assertIn('aaa', features) + self.assertIn('bbb', features) + with self.test_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([b'omar', b'stringer'], dtype=np.object_), + dense_shape=[1, 2]), + features['aaa'].eval()) + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([b'stringer', b'marlo'], dtype=np.object_), + dense_shape=[1, 2]), + features['bbb'].eval()) + + def test_input_layer(self): + # Inputs. + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [0] + # example 1, ids [] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (3, 0)), + values=(0, 1), + dense_shape=(4, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0: + # A ids [2], embedding = [7, 11] + # B ids [0], embedding = [1, 2] + (7., 11., 1., 2.), + # example 1: + # A ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + # B ids [], embedding = [0, 0] + (2., 3.5, 0., 0.), + # example 2: + # A ids [], embedding = [0, 0] + # B ids [], embedding = [0, 0] + (0., 0., 0., 0.), + # example 3: + # A ids [1], embedding = [3, 5] + # B ids [1], embedding = [3, 5] + (3., 5., 3., 5.), + ) + + # Build columns. + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, initializer=_initializer) + + # Provide sparse input and get dense result. + input_layer = fc.input_layer( + features={'aaa': sparse_input_a, 'bbb': sparse_input_b}, + feature_columns=(embedding_column_b, embedding_column_a)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ['input_layer/aaa_shared_embedding/aaa_bbb_shared_embedding_weights:0'], + tuple([v.name for v in global_vars])) + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual( + ['input_layer/aaa_shared_embedding/aaa_bbb_shared_embedding_weights:0'], + tuple([v.name for v in trainable_vars])) + shared_embedding_vars = ops.get_collection('aaa_bbb_shared_embedding') + self.assertItemsEqual( + ['input_layer/aaa_shared_embedding/aaa_bbb_shared_embedding_weights:0'], + tuple([v.name for v in shared_embedding_vars])) + with _initialized_session(): + self.assertAllEqual(embedding_values, trainable_vars[0].eval()) + self.assertAllEqual(expected_lookups, input_layer.eval()) + + class WeightedCategoricalColumnTest(test.TestCase): def test_defaults(self): diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index d51e142da19..bf3be34d851 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -55,10 +55,10 @@ from tensorflow.python.framework import tensor_util def _eager_reshape(tensor, shape, ctx): """Eager-only version of Reshape op; requires tensor is an eager Tensor.""" - attr_t = tensor.dtype.as_datatype_enum + attr_t = tensor._datatype_enum() # pylint: disable=protected-access attr_tshape, (shape,) = execute.args_to_matching_eager( [shape], ctx, dtypes.int32) - attr_tshape = attr_tshape.as_datatype_enum + attr_tshape = attr_tshape inputs_flat = [tensor, shape] attrs = ("T", attr_t, "Tshape", attr_tshape) result, = execute.execute( diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index c964fe4418a..503e7657701 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -605,11 +605,6 @@ class Tensor(_TensorLike): class _EagerTensorBase(Tensor): """Base class for EagerTensor.""" - @staticmethod - def _delete_trace(tid): - """Helper function to be called by __del__ of the subclass.""" - tape.delete_trace(tid) - @property def dtype(self): # Note: using the intern table directly here as this is @@ -720,11 +715,6 @@ class _EagerTensorBase(Tensor): new_tensor = self._copy_to_device(context=ctx._handle, device=device_name) except core._NotOkStatusException as e: six.raise_from(core._status_to_exception(e.code, e.message), None) - if core.active_trace() is not None: - core.active_trace().record_tensor("COPY", - tensor_id(new_tensor), - new_tensor.device, - new_tensor.shape.num_elements()) # Record the copy on tape and define backprop copy as well. if not context.in_graph_mode(): @@ -1540,13 +1530,6 @@ class Operation(object): raise TypeError("input needs to be a Tensor: %s" % a) # Mark that we consume the inputs. a._add_consumer(self) # pylint: disable=protected-access - if output_types is None: - output_types = [] - self._output_types_val = output_types - self._outputs = [ - Tensor(self, i, output_type) - for i, output_type in enumerate(output_types) - ] if input_types is None: input_types = [i.dtype.base_dtype for i in self._inputs] else: @@ -1576,25 +1559,6 @@ class Operation(object): self._original_op = original_op self._op_def = op_def self._traceback = self._graph._extract_stack() # pylint: disable=protected-access - # Define self._c_op before calling self._control_flow_context.AddOp(), since - # that will call methods on this op that check if self._c_op is set. - self._c_op = None - # Add this op to the current control flow context: - self._control_flow_context = g._get_control_flow_context() # pylint: disable=protected-access - if self._control_flow_context is not None: - # TODO(skyewm): consider refactoring this to call self._create_c_op() - # first. This would require updating the TF_Operation's ID (see the - # comment and self._id_value update below). The disadvantage of calling - # AddOp() first is that we need to maintain Operation state that is - # accessed by AddOp() in Python, e.g. the input Tensors. - self._control_flow_context.AddOp(self) - # NOTE(keveman): Control flow context's AddOp could be creating new ops and - # setting op.inputs[index] = new_op. Thus the new ops' id could be larger - # than this op's id even though this op depend on them. Therefore, delaying - # assigning id to this op until all ops this could be dependent on are - # created. - self._id_value = self._graph._next_id() # pylint: disable=protected-access - self._recompute_node_def() if self._graph._c_graph: # pylint: disable=protected-access if self._op_def: @@ -1608,6 +1572,29 @@ class Operation(object): self._c_op = _create_c_op(self._graph, self._node_def, grouped_inputs, self._control_inputs) + else: + self._c_op = None + + # Initialize self._outputs + if output_types is None: + output_types = [] + self._output_types_val = output_types + self._outputs = [ + Tensor(self, i, output_type) + for i, output_type in enumerate(output_types) + ] + + # Add this op to the current control flow context: + self._control_flow_context = g._get_control_flow_context() # pylint: disable=protected-access + if self._control_flow_context is not None: + self._control_flow_context.AddOp(self) + # NOTE(keveman): Control flow context's AddOp could be creating new ops and + # setting op.inputs[index] = new_op. Thus the new ops' id could be larger + # than this op's id even though this op depend on them. Therefore, delaying + # assigning id to this op until all ops this could be dependent on are + # created. + self._id_value = self._graph._next_id() # pylint: disable=protected-access + self._recompute_node_def() def _reconstruct_sequence_inputs(self, op_def, inputs, attrs): """Regroups a flat list of input tensors into scalar and sequence inputs. diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 3c62dfd133d..c57f0a98421 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -447,23 +447,48 @@ static void AddDelimiter(string* append_to, const string& delim) { if (!append_to->empty()) strings::StrAppend(append_to, delim); } -GenPythonOp::GenPythonOp(const OpDef& op_def, const string& function_name) +const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) { + for (int i = 0; i < api_def.attr_size(); ++i) { + if (api_def.attr(i).name() == name) { + return &api_def.attr(i); + } + } + return nullptr; +} + +const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { + for (int i = 0; i < api_def.in_arg_size(); ++i) { + if (api_def.in_arg(i).name() == name) { + return &api_def.in_arg(i); + } + } + return nullptr; +} + +GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def, + const string& function_name) : op_def_(op_def), + api_def_(api_def), function_name_(function_name), num_outs_(op_def.output_arg_size()) {} GenPythonOp::~GenPythonOp() {} string GenPythonOp::Code() { + if (api_def_.visibility() == ApiDef::SKIP) { + return ""; + } // This has all the input args followed by those attrs that don't have // defaults. std::vector args_no_default; // The parameters with defaults (these have to be listed after those without). // No input args are included, just attrs. std::vector args_with_defaults; - for (int i = 0; i < op_def_.input_arg_size(); ++i) { - const auto& arg(op_def_.input_arg(i)); - args_no_default.push_back(arg.name()); + + for (int i = 0; i < api_def_.arg_order_size(); ++i) { + const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_); + const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_); + args_no_default.push_back(api_def_arg.rename_to()); if (!arg.type_attr().empty()) { gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name()); } else if (!arg.type_list_attr().empty()) { @@ -474,14 +499,14 @@ string GenPythonOp::Code() { gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name()); } } - for (int i = 0; i < op_def_.attr_size(); ++i) { - const auto& attr(op_def_.attr(i)); + for (int i = 0; i < api_def_.attr_size(); ++i) { + const auto& attr(api_def_.attr(i)); // Do not add inferred attrs to the Python function signature. if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) { if (attr.has_default_value()) { - args_with_defaults.push_back(attr.name()); + args_with_defaults.push_back(attr.rename_to()); } else { - args_no_default.push_back(attr.name()); + args_no_default.push_back(attr.rename_to()); } } } @@ -515,6 +540,7 @@ string GenPythonOp::Code() { AddDelimiter(¶meters, ", "); strings::StrAppend(¶meters, "name=None"); + AddExport(); AddDefLine(parameters); AddDocStringDescription(); AddDocStringArgs(); @@ -530,18 +556,37 @@ string GenPythonOp::Code() { return prelude_ + result_; } +void GenPythonOp::AddExport() { + if (api_def_.visibility() != api_def_.VISIBLE) { + return; + } + strings::StrAppend(&result_, "tf_export("); + + // Add all endpoint names to tf_export. + bool first_endpoint = true; + for (const auto& endpoint : api_def_.endpoint()) { + if (!first_endpoint) { + strings::StrAppend(&result_, ", "); + } else { + first_endpoint = false; + } + strings::StrAppend(&result_, "'", endpoint.name(), "'"); + } + strings::StrAppend(&result_, ")\n"); +} + void GenPythonOp::AddDefLine(const string& parameters) { strings::StrAppend(&result_, "def ", function_name_, "(", parameters, "):\n"); } void GenPythonOp::AddDocStringDescription() { string comment; - if (op_def_.summary().empty()) { + if (api_def_.summary().empty()) { comment = "TODO: add doc.\n"; } else { - comment = strings::StrCat(op_def_.summary(), "\n"); - if (!op_def_.description().empty()) { - strings::StrAppend(&comment, "\n", Indent(2, 2, op_def_.description())); + comment = strings::StrCat(api_def_.summary(), "\n"); + if (!api_def_.description().empty()) { + strings::StrAppend(&comment, "\n", Indent(2, 2, api_def_.description())); } } strings::StrAppend(&result_, " r\"\"\"", comment, "\n"); @@ -552,9 +597,10 @@ void GenPythonOp::AddDocStringArgs() { } void GenPythonOp::AddDocStringInputs() { - for (int i = 0; i < op_def_.input_arg_size(); ++i) { - const auto& arg(op_def_.input_arg(i)); - StringPiece description = op_def_.input_arg(i).description(); + for (int i = 0; i < api_def_.arg_order_size(); ++i) { + const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_); + const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_); + StringPiece description = api_def_arg.description(); string desc; if (ConsumeEquals(&description)) { // Skip the generated type info. desc = strings::StrCat(param_names_[i], ": "); @@ -572,7 +618,9 @@ void GenPythonOp::AddDocStringInputs() { void GenPythonOp::AddDocStringAttrs() { for (const string& name : attrs_) { const auto& attr = *FindAttr(name, op_def_); - string desc = strings::StrCat(AvoidPythonReserved(name), ": "); + const auto& api_def_attr = *FindAttr(name, api_def_); + string desc = + strings::StrCat(AvoidPythonReserved(api_def_attr.rename_to()), ": "); static const char* const kAttrTypeName[][2] = { {"string", "`string`"}, @@ -596,7 +644,7 @@ void GenPythonOp::AddDocStringAttrs() { for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) { if (attr.type() == kAttrTypeName[i][0]) { string s; - if (attr.has_default_value()) { + if (api_def_attr.has_default_value()) { s = strings::StrCat("optional ", kAttrTypeName[i][1]); } else { s = kAttrTypeName[i][1]; @@ -625,14 +673,13 @@ void GenPythonOp::AddDocStringAttrs() { strings::StrAppend(&desc, "."); - if (attr.has_default_value()) { - strings::StrAppend(&desc, " Defaults to `", - AttrValueToPython(attr.type(), attr.default_value()), - "`."); + if (api_def_attr.has_default_value()) { + strings::StrAppend( + &desc, " Defaults to `", + AttrValueToPython(attr.type(), api_def_attr.default_value()), "`."); } - - if (!attr.description().empty()) { - AppendWithinWidth(&desc, attr.description(), + if (!api_def_attr.description().empty()) { + AppendWithinWidth(&desc, api_def_attr.description(), kRightMargin - 4 /* indent */); } strings::StrAppend(&result_, Indent(4, 6, desc)); @@ -650,8 +697,8 @@ void GenPythonOp::AddOutputGlobals() { // Prepare the list of output names std::vector out_names(num_outs_); for (int i = 0; i < num_outs_; ++i) { - if (!op_def_.output_arg(i).name().empty()) { - out_names[i] = op_def_.output_arg(i).name(); + if (!api_def_.out_arg(i).rename_to().empty()) { + out_names[i] = api_def_.out_arg(i).rename_to(); } else { out_names[i] = strings::StrCat("output", i); } @@ -714,11 +761,14 @@ void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) { } // namespace python_op_gen_internal -string GetPythonOp(const OpDef& op_def, const string& function_name) { - return python_op_gen_internal::GenPythonOp(op_def, function_name).Code(); +string GetPythonOp(const OpDef& op_def, const ApiDef& api_def, + const string& function_name) { + return python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) + .Code(); } -string GetPythonOps(const OpList& ops, const std::vector& hidden_ops, +string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs, + const std::vector& hidden_ops, bool require_shapes) { string result; // Header @@ -738,6 +788,7 @@ from tensorflow.python.framework import common_shapes as _common_shapes from tensorflow.python.framework import op_def_registry as _op_def_registry from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import op_def_library as _op_def_library +from tensorflow.python.util.tf_export import tf_export )"); // We'll make a copy of ops that filters out descriptions. @@ -766,7 +817,8 @@ from tensorflow.python.framework import op_def_library as _op_def_library continue; } - strings::StrAppend(&result, GetPythonOp(op_def, function_name)); + const auto* api_def = api_defs.GetApiDef(op_def.name()); + strings::StrAppend(&result, GetPythonOp(op_def, *api_def, function_name)); if (!require_shapes) { strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(), @@ -799,16 +851,18 @@ from tensorflow.python.framework import op_def_library as _op_def_library return result; } -void PrintPythonOps(const OpList& ops, const std::vector& hidden_ops, +void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs, + const std::vector& hidden_ops, bool require_shapes) { - printf("%s", GetPythonOps(ops, hidden_ops, require_shapes).c_str()); + printf("%s", GetPythonOps(ops, api_defs, hidden_ops, require_shapes).c_str()); } string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) { string op_list_str(op_list_buf, op_list_len); OpList ops; ops.ParseFromString(op_list_str); - return GetPythonOps(ops, {}, false); + ApiDefMap api_def_map(ops); + return GetPythonOps(ops, api_def_map, {}, false); } } // namespace tensorflow diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h index f485044c5af..4d20888dc63 100644 --- a/tensorflow/python/framework/python_op_gen.h +++ b/tensorflow/python/framework/python_op_gen.h @@ -18,20 +18,23 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { -// hidden_ops should be a comma-separated -// list of Op names that should get a leading _ in the output. +// hidden_ops should be a vector of Op names that should get a leading _ in the +// output. // The Print* version prints the output to stdout, Get* version returns the // output as a string. -void PrintPythonOps(const OpList& ops, const std::vector& hidden_ops, - bool require_shapes); -string GetPythonOps(const OpList& ops, const std::vector& hidden_ops, - bool require_shapes); -string GetPythonOp(const OpDef& op_def, const string& function_name); +void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs, + const std::vector& hidden_ops, bool require_shapes); +string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs, + const std::vector& hidden_ops, bool require_shapes); +string GetPythonOp(const OpDef& op_def, const ApiDef& api_def, + const string& function_name); // Get the python wrappers for a list of ops in a OpList. // `op_list_buf` should be a pointer to a buffer containing diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h index 92237ac81a2..c1efbf9be22 100644 --- a/tensorflow/python/framework/python_op_gen_internal.h +++ b/tensorflow/python/framework/python_op_gen_internal.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/platform/types.h" @@ -42,7 +43,8 @@ string DataTypeToPython(DataType dtype, const string& dtype_module); class GenPythonOp { public: - GenPythonOp(const OpDef& op_def, const string& function_name); + GenPythonOp(const OpDef& op_def, const ApiDef& api_def, + const string& function_name); virtual ~GenPythonOp(); virtual string Code(); @@ -62,9 +64,11 @@ class GenPythonOp { void AddDocStringOutputs(); void AddBody(const string& prefix); void AddBodyNoReturn(const string& apply_prefix); + void AddExport(); // From constructor arguments const OpDef& op_def_; + const ApiDef& api_def_; const string function_name_; const int num_outs_; diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc index f681daa7e46..61b1d02a5e8 100644 --- a/tensorflow/python/framework/python_op_gen_main.cc +++ b/tensorflow/python/framework/python_op_gen_main.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/inputbuffer.h" #include "tensorflow/core/lib/io/path.h" @@ -33,6 +34,12 @@ limitations under the License. namespace tensorflow { namespace { +constexpr char kBaseApiDef[] = + "tensorflow/core/api_def/base_api/*.pbtxt"; +constexpr char kPythonApiDef[] = + "tensorflow/core/api_def/python_api/*.pbtxt"; +constexpr bool kUseApiDef = false; + Status ReadOpListFromFile(const string& filename, std::vector* op_list) { std::unique_ptr file; @@ -108,6 +115,19 @@ void PrintAllPythonOps(const std::vector& op_list, OpList ops; OpRegistry::Global()->Export(false, &ops); + ApiDefMap api_def_map(ops); + if (kUseApiDef) { + Env* env = Env::Default(); + + std::vector base_api_files; + std::vector python_api_files; + TF_CHECK_OK(env->GetMatchingPaths(kBaseApiDef, &base_api_files)); + TF_CHECK_OK(env->GetMatchingPaths(kPythonApiDef, &python_api_files)); + + TF_CHECK_OK(api_def_map.LoadFileList(env, base_api_files)); + TF_CHECK_OK(api_def_map.LoadFileList(env, python_api_files)); + } + if (op_list_is_whitelist) { std::unordered_set whitelist(op_list.begin(), op_list.end()); OpList pruned_ops; @@ -116,9 +136,11 @@ void PrintAllPythonOps(const std::vector& op_list, *pruned_ops.mutable_op()->Add() = op_def; } } - PrintEagerPythonOps(pruned_ops, {}, require_shapes, source_file_name); + PrintEagerPythonOps(pruned_ops, api_def_map, {}, require_shapes, + source_file_name); } else { - PrintEagerPythonOps(ops, op_list, require_shapes, source_file_name); + PrintEagerPythonOps(ops, api_def_map, op_list, require_shapes, + source_file_name); } } diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index bc9d9104473..99a4d23b6aa 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -88,8 +88,12 @@ def loop(): def get_config(layout_optimizer=True): - rewrite_options = rewriter_config_pb2.RewriterConfig( - optimize_tensor_layout=layout_optimizer) + if layout_optimizer: + rewrite_options = rewriter_config_pb2.RewriterConfig( + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) + else: + rewrite_options = rewriter_config_pb2.RewriterConfig( + layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions( rewrite_options=rewrite_options, build_cost_model=1) config = config_pb2.ConfigProto(graph_options=graph_options) @@ -183,7 +187,8 @@ class LayoutOptimizerTest(test.TestCase): self.skipTest('GPU required') random_seed.set_random_seed(0) - x = random_ops.truncated_normal([1, 200, 200, 3], seed=0) + x = variables.Variable( + random_ops.truncated_normal([1, 200, 200, 3], seed=0)) y = conv_layers.conv2d(x, 32, [3, 3]) z = conv_layers.conv2d(y, 32, [3, 3]) optimizer = gradient_descent.GradientDescentOptimizer(1e-4) @@ -194,7 +199,7 @@ class LayoutOptimizerTest(test.TestCase): meta_graph = saver_lib.export_meta_graph(graph_def=graph.as_graph_def()) rewrite_options = rewriter_config_pb2.RewriterConfig( - optimize_tensor_layout=True) + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) optimized_graph = tf_optimizer.OptimizeGraph(rewrite_options, meta_graph) found = 0 diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 55b5d7ff613..e4992afbca7 100644 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -30,6 +30,7 @@ py_library( "_impl/keras/datasets/cifar.py", "_impl/keras/datasets/cifar10.py", "_impl/keras/datasets/cifar100.py", + "_impl/keras/datasets/fashion_mnist.py", "_impl/keras/datasets/imdb.py", "_impl/keras/datasets/mnist.py", "_impl/keras/datasets/reuters.py", @@ -89,6 +90,7 @@ py_library( "datasets/boston_housing/__init__.py", "datasets/cifar10/__init__.py", "datasets/cifar100/__init__.py", + "datasets/fashion_mnist/__init__.py", "datasets/imdb/__init__.py", "datasets/mnist/__init__.py", "datasets/reuters/__init__.py", @@ -588,6 +590,18 @@ py_test( ], ) +py_test( + name = "np_utils_test", + size = "small", + srcs = ["_impl/keras/utils/np_utils_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + py_test( name = "training_utils_test", size = "medium", diff --git a/tensorflow/python/keras/_impl/keras/__init__.py b/tensorflow/python/keras/_impl/keras/__init__.py index f0e8d91a929..74cc9d0488c 100644 --- a/tensorflow/python/keras/_impl/keras/__init__.py +++ b/tensorflow/python/keras/_impl/keras/__init__.py @@ -40,4 +40,4 @@ from tensorflow.python.keras._impl.keras.layers import Input from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.models import Sequential -__version__ = '2.0.8-tf' +__version__ = '2.1.1-tf' diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py index f9a53c4eb4d..b029e5161f7 100644 --- a/tensorflow/python/keras/_impl/keras/backend.py +++ b/tensorflow/python/keras/_impl/keras/backend.py @@ -2486,11 +2486,21 @@ def print_tensor(x, message=''): class Function(object): """Runs a computation graph. + It's possible to pass arguments to `tf.Session.run()` via `session_kwargs`. + In particular additonal operations via `fetches` argument and additional + tensor substitutions via `feed_dict` arguments. Note that given + substitutions are merged with substitutions from `inputs`. Even though + `feed_dict` is passed once in the constructor (called in `model.compile()`) + we can modify the values in the dictionary. Through this feed_dict we can + provide additional substitutions besides Keras inputs. + Arguments: inputs: Feed placeholders to the computation graph. outputs: Output tensors to fetch. updates: Additional update ops to be run at function call. - name: a name to help users identify what this function does. + name: A name to help users identify what this function does. + session_kwargs: Arguments to `tf.Session.run()`: `fetches`, `feed_dict`, + `options`, `run_metadata` """ def __init__(self, inputs, outputs, updates=None, name=None, @@ -2518,12 +2528,18 @@ class Function(object): updates_ops.append(update) self.updates_op = control_flow_ops.group(*updates_ops) self.name = name + # additional tensor substitutions + self.feed_dict = session_kwargs.pop('feed_dict', {}) + # additional operations + self.fetches = session_kwargs.pop('fetches', []) + if not isinstance(self.fetches, list): + self.fetches = [self.fetches] self.session_kwargs = session_kwargs def __call__(self, inputs): if not isinstance(inputs, (list, tuple)): raise TypeError('`inputs` should be a list or tuple.') - feed_dict = {} + feed_dict = self.feed_dict.copy() for tensor, value in zip(self.inputs, inputs): if is_sparse(tensor): sparse_coo = value.tocoo() @@ -2531,11 +2547,10 @@ class Function(object): np.expand_dims(sparse_coo.col, 1)), 1) value = (indices, sparse_coo.data, sparse_coo.shape) feed_dict[tensor] = value + fetches = self.outputs + [self.updates_op] + self.fetches session = get_session() updated = session.run( - self.outputs + [self.updates_op], - feed_dict=feed_dict, - **self.session_kwargs) + fetches=fetches, feed_dict=feed_dict, **self.session_kwargs) return updated[:len(self.outputs)] diff --git a/tensorflow/python/keras/_impl/keras/backend_test.py b/tensorflow/python/keras/_impl/keras/backend_test.py index 5eaae31d921..e45e566dcac 100644 --- a/tensorflow/python/keras/_impl/keras/backend_test.py +++ b/tensorflow/python/keras/_impl/keras/backend_test.py @@ -165,6 +165,55 @@ class BackendUtilsTest(test.TestCase): for y in ys: self.assertEqual(y.op.name[:12], 'StopGradient') + def test_function_tf_fetches(self): + # Additional operations can be passed to tf.Session().run() via its + # `fetches` arguments. In contrast to `updates` argument of + # keras.backend.function() these do not have control dependency on `outputs` + # so they can run in parallel. Also they should not contribute to output of + # keras.backend.function(). + with self.test_session(): + x = keras.backend.variable(0.) + y = keras.backend.variable(0.) + x_placeholder = keras.backend.placeholder(shape=()) + y_placeholder = keras.backend.placeholder(shape=()) + + f = keras.backend.function(inputs=[x_placeholder, y_placeholder], + outputs=[x_placeholder + y_placeholder], + updates=[(x, x_placeholder + 1.)], + fetches=[keras.backend.update(y, 5.)]) + output = f([10., 20.]) + assert output == [30.] + assert keras.backend.get_session().run(fetches=[x, y]) == [11., 5.] + + def test_function_tf_feed_dict(self): + # Additional substitutions can be passed to `tf.Session().run()` via its + # `feed_dict` arguments. Note that the feed_dict is passed once in the + # constructor but we can modify the values in the dictionary. Through + # this feed_dict we can provide additional substitutions besides Keras + # inputs. + with self.test_session(): + x = keras.backend.variable(0.) + y = keras.backend.variable(0.) + x_placeholder = keras.backend.placeholder(shape=()) + y_placeholder = keras.backend.placeholder(shape=()) + + feed_dict = {y_placeholder: 3.} + fetches = [keras.backend.update(y, y_placeholder * 10.)] + f = keras.backend.function(inputs=[x_placeholder], + outputs=[x_placeholder + 1.], + updates=[(x, x_placeholder + 10.)], + feed_dict=feed_dict, + fetches=fetches) + output = f([10.]) + assert output == [11.] + assert keras.backend.get_session().run(fetches=[x, y]) == [20., 30.] + + # updated value in feed_dict will be modified within the K.function() + feed_dict[y_placeholder] = 4. + output = f([20.]) + assert output == [21.] + assert keras.backend.get_session().run(fetches=[x, y]) == [30., 40.] + class BackendVariableTest(test.TestCase): diff --git a/tensorflow/python/keras/_impl/keras/callbacks.py b/tensorflow/python/keras/_impl/keras/callbacks.py index eb678c4d1d9..40a996a03f7 100644 --- a/tensorflow/python/keras/_impl/keras/callbacks.py +++ b/tensorflow/python/keras/_impl/keras/callbacks.py @@ -265,7 +265,7 @@ class ProgbarLogger(Callback): Arguments: count_mode: One of "steps" or "samples". Whether the progress bar should - count samples seens or steps (batches) seen. + count samples seen or steps (batches) seen. Raises: ValueError: In case of invalid `count_mode`. @@ -417,7 +417,7 @@ class ModelCheckpoint(Callback): self.epochs_since_last_save += 1 if self.epochs_since_last_save >= self.period: self.epochs_since_last_save = 0 - filepath = self.filepath.format(epoch=epoch, **logs) + filepath = self.filepath.format(epoch=epoch + 1, **logs) if self.save_best_only: current = logs.get(self.monitor) if current is None: @@ -427,7 +427,7 @@ class ModelCheckpoint(Callback): if self.monitor_op(current, self.best): if self.verbose > 0: print('Epoch %05d: %s improved from %0.5f to %0.5f,' - ' saving model to %s' % (epoch, self.monitor, self.best, + ' saving model to %s' % (epoch + 1, self.monitor, self.best, current, filepath)) self.best = current if self.save_weights_only: @@ -436,10 +436,11 @@ class ModelCheckpoint(Callback): self.model.save(filepath, overwrite=True) else: if self.verbose > 0: - print('Epoch %05d: %s did not improve' % (epoch, self.monitor)) + print('Epoch %05d: %s did not improve' % (epoch + 1, + self.monitor)) else: if self.verbose > 0: - print('Epoch %05d: saving model to %s' % (epoch, filepath)) + print('Epoch %05d: saving model to %s' % (epoch + 1, filepath)) if self.save_weights_only: self.model.save_weights(filepath, overwrite=True) else: @@ -519,14 +520,14 @@ class EarlyStopping(Callback): self.best = current self.wait = 0 else: + self.wait += 1 if self.wait >= self.patience: self.stopped_epoch = epoch self.model.stop_training = True - self.wait += 1 def on_train_end(self, logs=None): if self.stopped_epoch > 0 and self.verbose > 0: - print('Epoch %05d: early stopping' % (self.stopped_epoch)) + print('Epoch %05d: early stopping' % (self.stopped_epoch + 1)) class RemoteMonitor(Callback): diff --git a/tensorflow/python/keras/_impl/keras/callbacks_test.py b/tensorflow/python/keras/_impl/keras/callbacks_test.py index d9d7fb5a9fb..97a650a9920 100644 --- a/tensorflow/python/keras/_impl/keras/callbacks_test.py +++ b/tensorflow/python/keras/_impl/keras/callbacks_test.py @@ -203,12 +203,12 @@ class KerasCallbacksTest(test.TestCase): callbacks=cbks, epochs=4, verbose=1) - assert os.path.exists(filepath.format(epoch=1)) - assert os.path.exists(filepath.format(epoch=3)) - os.remove(filepath.format(epoch=1)) - os.remove(filepath.format(epoch=3)) - assert not os.path.exists(filepath.format(epoch=0)) - assert not os.path.exists(filepath.format(epoch=2)) + assert os.path.exists(filepath.format(epoch=2)) + assert os.path.exists(filepath.format(epoch=4)) + os.remove(filepath.format(epoch=2)) + os.remove(filepath.format(epoch=4)) + assert not os.path.exists(filepath.format(epoch=1)) + assert not os.path.exists(filepath.format(epoch=3)) # Invalid use: this will raise a warning but not an Exception. keras.callbacks.ModelCheckpoint( @@ -273,12 +273,12 @@ class KerasCallbacksTest(test.TestCase): stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience) weights = model.get_weights() - hist = model.fit(data, labels, callbacks=[stopper], verbose=0) + hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) assert len(hist.epoch) >= patience # This should allow training to go for at least `patience` epochs model.set_weights(weights) - hist = model.fit(data, labels, callbacks=[stopper], verbose=0) + hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) assert len(hist.epoch) >= patience def test_RemoteMonitor(self): @@ -571,7 +571,6 @@ class KerasCallbacksTest(test.TestCase): loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy']) - tsb = keras.callbacks.TensorBoard( log_dir=temp_dir, histogram_freq=1, write_images=True, write_grads=True, batch_size=5) diff --git a/tensorflow/python/keras/_impl/keras/datasets/__init__.py b/tensorflow/python/keras/_impl/keras/datasets/__init__.py index 22afb6a5534..60db3766fbc 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/__init__.py +++ b/tensorflow/python/keras/_impl/keras/datasets/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Keras datasets: utilities for downloading and pre-processing common datasets. + """ from __future__ import absolute_import from __future__ import division @@ -21,7 +22,7 @@ from __future__ import print_function from tensorflow.python.keras._impl.keras.datasets import boston_housing from tensorflow.python.keras._impl.keras.datasets import cifar10 from tensorflow.python.keras._impl.keras.datasets import cifar100 +from tensorflow.python.keras._impl.keras.datasets import fashion_mnist from tensorflow.python.keras._impl.keras.datasets import imdb from tensorflow.python.keras._impl.keras.datasets import mnist from tensorflow.python.keras._impl.keras.datasets import reuters - diff --git a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py index e4f7fb9d212..4359be89280 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py +++ b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py @@ -48,9 +48,10 @@ def load_data(path='boston_housing.npz', seed=113, test_split=0.2): f.close() np.random.seed(seed) - np.random.shuffle(x) - np.random.seed(seed) - np.random.shuffle(y) + indices = np.arrange(len(x)) + np.random.shuffle(indices) + x = x[indices] + y = y[indices] x_train = np.array(x[:int(len(x) * (1 - test_split))]) y_train = np.array(y[:int(len(x) * (1 - test_split))]) diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py index 672249ff20f..7905da66c1e 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py +++ b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py @@ -34,19 +34,18 @@ def load_data(): Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. """ dirname = 'cifar-10-batches-py' - origin = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' + origin = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' path = get_file(dirname, origin=origin, untar=True) num_train_samples = 50000 - x_train = np.zeros((num_train_samples, 3, 32, 32), dtype='uint8') - y_train = np.zeros((num_train_samples,), dtype='uint8') + x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8') + y_train = np.empty((num_train_samples,), dtype='uint8') for i in range(1, 6): fpath = os.path.join(path, 'data_batch_' + str(i)) - data, labels = load_batch(fpath) - x_train[(i - 1) * 10000:i * 10000, :, :, :] = data - y_train[(i - 1) * 10000:i * 10000] = labels + (x_train[(i - 1) * 10000:i * 10000, :, :, :], + y_train[(i - 1) * 10000:i * 10000]) = load_batch(fpath) fpath = os.path.join(path, 'test_batch') x_test, y_test = load_batch(fpath) diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py index 1be7483d273..b69c0724c58 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py +++ b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py @@ -43,7 +43,7 @@ def load_data(label_mode='fine'): raise ValueError('label_mode must be one of "fine" "coarse".') dirname = 'cifar-100-python' - origin = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' + origin = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' path = get_file(dirname, origin=origin, untar=True) fpath = os.path.join(path, 'train') diff --git a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py new file mode 100644 index 00000000000..17be684e4f8 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fashion-MNIST dataset. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import numpy as np +from tensorflow.python.keras._impl.keras.utils.data_utils import get_file + + +def load_data(): + """Loads the Fashion-MNIST dataset. + + Returns: + Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. + """ + dirname = os.path.join('datasets', 'fashion-mnist') + base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' + files = [ + 'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz', + 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz' + ] + + paths = [] + for given_file in files: + paths.append( + get_file(given_file, origin=base + given_file, cache_subdir=dirname)) + + with gzip.open(paths[0], 'rb') as lbpath: + y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) + + with gzip.open(paths[1], 'rb') as imgpath: + x_train = np.frombuffer( + imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28) + + with gzip.open(paths[2], 'rb') as lbpath: + y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8) + + with gzip.open(paths[3], 'rb') as imgpath: + x_test = np.frombuffer( + imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28) + + return (x_train, y_train), (x_test, y_test) diff --git a/tensorflow/python/keras/_impl/keras/datasets/imdb.py b/tensorflow/python/keras/_impl/keras/datasets/imdb.py index 0db9d61f6d5..0e83473899c 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/imdb.py +++ b/tensorflow/python/keras/_impl/keras/datasets/imdb.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -65,23 +65,24 @@ def load_data(path='imdb.npz', have simply been skipped. """ path = get_file( - path, origin='https://s3.amazonaws.com/text-datasets/imdb.npz') + path, + origin='https://s3.amazonaws.com/text-datasets/imdb.npz', + file_hash='599dadb1135973df5b59232a0e9a887c') f = np.load(path) - x_train = f['x_train'] - labels_train = f['y_train'] - x_test = f['x_test'] - labels_test = f['y_test'] + x_train, labels_train = f['x_train'], f['y_train'] + x_test, labels_test = f['x_test'], f['y_test'] f.close() np.random.seed(seed) - np.random.shuffle(x_train) - np.random.seed(seed) - np.random.shuffle(labels_train) + indices = np.arrange(len(x_train)) + np.random.shuffle(indices) + x_train = x_train[indices] + labels_train = labels_train[indices] - np.random.seed(seed * 2) - np.random.shuffle(x_test) - np.random.seed(seed * 2) - np.random.shuffle(labels_test) + indices = np.arrange(len(x_test)) + np.random.shuffle(indices) + x_test = x_test[indices] + labels_test = labels_test[indices] xs = np.concatenate([x_train, x_test]) labels = np.concatenate([labels_train, labels_test]) diff --git a/tensorflow/python/keras/_impl/keras/datasets/mnist.py b/tensorflow/python/keras/_impl/keras/datasets/mnist.py index 02be5e2a407..e98f29537f4 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/mnist.py +++ b/tensorflow/python/keras/_impl/keras/datasets/mnist.py @@ -34,7 +34,9 @@ def load_data(path='mnist.npz'): Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. """ path = get_file( - path, origin='https://s3.amazonaws.com/img-datasets/mnist.npz') + path, + origin='https://s3.amazonaws.com/img-datasets/mnist.npz', + file_hash='8a61469f7ea1b51cbae51d4f78837e45') f = np.load(path) x_train = f['x_train'] y_train = f['y_train'] diff --git a/tensorflow/python/keras/_impl/keras/datasets/reuters.py b/tensorflow/python/keras/_impl/keras/datasets/reuters.py index c36bac5cc7d..d05eb0ef8ca 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/reuters.py +++ b/tensorflow/python/keras/_impl/keras/datasets/reuters.py @@ -64,15 +64,20 @@ def load_data(path='reuters.npz', have simply been skipped. """ path = get_file( - path, origin='https://s3.amazonaws.com/text-datasets/reuters.npz') + path, + origin='https://s3.amazonaws.com/text-datasets/reuters.npz', + file_hash='87aedbeb0cb229e378797a632c1997b6') npzfile = np.load(path) xs = npzfile['x'] labels = npzfile['y'] npzfile.close() np.random.seed(seed) - np.random.shuffle(xs) - np.random.seed(seed) + indices = np.arrange(len(xs)) + np.random.shuffle(indices) + xs = xs[indices] + labels = labels[indices] + np.random.shuffle(labels) if start_char is not None: @@ -129,7 +134,8 @@ def get_word_index(path='reuters_word_index.json'): """ path = get_file( path, - origin='https://s3.amazonaws.com/text-datasets/reuters_word_index.json') + origin='https://s3.amazonaws.com/text-datasets/reuters_word_index.json', + file_hash='4d44cc38712099c9e383dc6e5f11a921') f = open(path) data = json.load(f) f.close() diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py index 1b7ddef9c45..4a7bb2e8389 100644 --- a/tensorflow/python/keras/_impl/keras/engine/topology.py +++ b/tensorflow/python/keras/_impl/keras/engine/topology.py @@ -36,6 +36,8 @@ from tensorflow.python.keras._impl.keras.utils import conv_utils from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary from tensorflow.python.layers import base as tf_base_layers +from tensorflow.python.layers import network as tf_network +from tensorflow.python.layers import utils as tf_layers_util from tensorflow.python.platform import tf_logging as logging @@ -485,7 +487,7 @@ class Layer(tf_base_layers.Layer): self._activity_regularizer = activity_regularizer -class InputLayer(tf_base_layers.InputLayer, Layer): +class InputLayer(tf_network.InputLayer, Layer): """Layer to be used as an entry point into a graph. It can either wrap an existing tensor (pass an `input_tensor` argument) @@ -636,7 +638,7 @@ def Input( # pylint: disable=invalid-name return outputs -class Network(tf_base_layers.Network, Layer): +class Network(tf_network.GraphNetwork, Layer): """A Network is a directed acyclic graph of layers. It is the topological form of a "model". A Model @@ -681,8 +683,8 @@ class Network(tf_base_layers.Network, Layer): for x in self.inputs: mask = x._keras_mask if hasattr(x, '_keras_mask') else None masks.append(mask) - mask_cache_key = (tf_base_layers._object_list_uid(self.inputs) + '_' + - tf_base_layers._object_list_uid(masks)) + mask_cache_key = (tf_layers_util.object_list_uid(self.inputs) + '_' + + tf_layers_util.object_list_uid(masks)) masks = [] for x in self.outputs: mask = x._keras_mask if hasattr(x, '_keras_mask') else None @@ -798,8 +800,8 @@ class Network(tf_base_layers.Network, Layer): else: kept_nodes = 0 for original_node_index, node in enumerate(layer._inbound_nodes): - node_key = tf_base_layers._make_node_key(layer.name, - original_node_index) + node_key = tf_network._make_node_key(layer.name, + original_node_index) if node_key in self._network_nodes: node_conversion_map[node_key] = kept_nodes kept_nodes += 1 @@ -809,8 +811,8 @@ class Network(tf_base_layers.Network, Layer): layer_config = layer.get_config() filtered_inbound_nodes = [] for original_node_index, node in enumerate(layer._inbound_nodes): - node_key = tf_base_layers._make_node_key(layer.name, - original_node_index) + node_key = tf_network._make_node_key(layer.name, + original_node_index) if node_key in self._network_nodes: # The node is relevant to the model: # add to filtered_inbound_nodes. @@ -834,8 +836,8 @@ class Network(tf_base_layers.Network, Layer): inbound_layer = node.inbound_layers[i] node_index = node.node_indices[i] tensor_index = node.tensor_indices[i] - node_key = tf_base_layers._make_node_key(inbound_layer.name, - node_index) + node_key = tf_network._make_node_key(inbound_layer.name, + node_index) new_node_index = node_conversion_map.get(node_key, 0) node_data.append( [inbound_layer.name, new_node_index, tensor_index, kwargs]) @@ -852,8 +854,8 @@ class Network(tf_base_layers.Network, Layer): model_inputs = [] for i in range(len(self._input_layers)): layer, node_index, tensor_index = self._input_coordinates[i] - node_key = tf_base_layers._make_node_key(layer.name, - node_index) + node_key = tf_network._make_node_key(layer.name, + node_index) if node_key not in self._network_nodes: continue new_node_index = node_conversion_map[node_key] @@ -862,8 +864,8 @@ class Network(tf_base_layers.Network, Layer): model_outputs = [] for i in range(len(self._output_layers)): layer, node_index, tensor_index = self._output_coordinates[i] - node_key = tf_base_layers._make_node_key(layer.name, - node_index) + node_key = tf_network._make_node_key(layer.name, + node_index) if node_key not in self._network_nodes: continue new_node_index = node_conversion_map[node_key] @@ -1422,6 +1424,31 @@ def preprocess_weights_for_loading(layer, weights[0] = np.transpose(weights[0], (3, 2, 0, 1)) if layer.__class__.__name__ == 'ConvLSTM2D': weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) + + # convert the weights of CuDNNLSTM so that they could be loaded into LSTM + if layer.__class__.__name__ == 'LSTM': + # determine if we're loading a CuDNNLSTM layer from the number of bias + # weights: + # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) + units = weights[1].shape[0] + bias = weights[2] + if len(bias) == units * 8: + # reshape the kernels + kernels = np.split(weights[0], 4, axis=1) + kernels = [ + kernel.reshape(-1).reshape(kernel.shape, order='F') + for kernel in kernels + ] + weights[0] = np.concatenate(kernels, axis=1) + + # transpose the recurrent kernels + recurrent_kernels = np.split(weights[1], 4, axis=1) + recurrent_kernels = [kernel.T for kernel in recurrent_kernels] + weights[1] = np.concatenate(recurrent_kernels, axis=1) + + # split the bias into half and merge + weights[2] = bias[:units * 4] + bias[units * 4:] + return weights diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py index b1e48439ba0..b4205bf4a39 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training.py +++ b/tensorflow/python/keras/_impl/keras/engine/training.py @@ -71,6 +71,9 @@ def _standardize_input_data(data, if data is None: return [None for _ in range(len(names))] if isinstance(data, dict): + for key, value in data.items(): + if value.__class__.__name__ == 'DataFrame': + data[key] = value.values arrays = [] for name in names: if name not in data: @@ -78,6 +81,9 @@ def _standardize_input_data(data, '". Need data for each key in: ' + str(names)) arrays.append(data[name]) elif isinstance(data, list): + for key, value in enumerate(data): + if value.__class__.__name__ == 'DataFrame': + data[key] = value.values if len(data) != len(names): if data and hasattr(data[0], 'shape'): raise ValueError( @@ -100,6 +106,9 @@ def _standardize_input_data(data, ' Numpy arrays instead. ' 'The list you passed was: ' + str(data)[:200]) arrays = data + elif data.__class__.__name__ == 'DataFrame': + # test if data is a DataFrame, without pandas installed + arrays = data.values else: if not hasattr(data, 'shape'): raise TypeError('Error when checking model ' + exception_prefix + @@ -262,12 +271,13 @@ def _check_loss_and_target_compatibility(targets, loss_fns, output_shapes): is incompatible with an output. """ key_losses = { - 'mean_squared_error', 'binary_crossentropy', 'categorical_crossentropy' + losses.mean_squared_error, losses.binary_crossentropy, + losses.categorical_crossentropy } for y, loss, shape in zip(targets, loss_fns, output_shapes): if loss is None: continue - if loss.__name__ == 'categorical_crossentropy': + if loss is losses.categorical_crossentropy: if y.shape[-1] == 1: raise ValueError('You are passing a target array of shape ' + str( y.shape) + ' while using as loss `categorical_crossentropy`. ' @@ -277,14 +287,14 @@ def _check_loss_and_target_compatibility(targets, loss_fns, output_shapes): 'If your targets are integer classes, ' 'you can convert them to the expected format via:\n' '```\n' - 'from keras.utils.np_utils import to_categorical\n' + 'from keras.utils import to_categorical\n' 'y_binary = to_categorical(y_int)\n' '```\n' '\n' 'Alternatively, you can use the loss function ' '`sparse_categorical_crossentropy` instead, ' 'which does expect integer targets.') - if loss.__name__ in key_losses: + if loss in key_losses: for target_dim, out_dim in zip(y.shape[1:], shape[1:]): if out_dim is not None and target_dim != out_dim: raise ValueError('A target array with shape ' + str(y.shape) + @@ -367,7 +377,7 @@ def _make_batches(size, batch_size): """ num_batches = int(np.ceil(size / float(batch_size))) return [(i * batch_size, min(size, (i + 1) * batch_size)) - for i in range(0, num_batches)] + for i in range(num_batches)] def _slice_arrays(arrays, start=None, stop=None): @@ -575,7 +585,7 @@ class Model(Network): """Configures the model for training. Arguments: - optimizer: String (name of optimizer) or optimizer object. + optimizer: String (name of optimizer) or optimizer instance. See [optimizers](/optimizers). loss: String (name of objective function) or objective function. See [losses](/losses). @@ -614,9 +624,7 @@ class Model(Network): can specify them via the `target_tensors` argument. It can be a single tensor (for a single-output model), a list of tensors, or a dict mapping output names to target tensors. - **kwargs: When using the Theano/CNTK backends, these arguments - are passed into K.function. When using the TensorFlow backend, - these arguments are passed into `tf.Session.run`. + **kwargs: These arguments are passed to `tf.Session.run`. Raises: ValueError: In case of invalid arguments for @@ -627,6 +635,7 @@ class Model(Network): self.sample_weight_mode = sample_weight_mode self.loss = loss self.loss_weights = loss_weights + self.sample_weight_mode = sample_weight_mode # Prepare loss functions. if isinstance(loss, dict): @@ -936,9 +945,28 @@ class Model(Network): trainable_weights = self.trainable_weights self._collected_trainable_weights = trainable_weights + def _check_trainable_weights_consistency(self): + """Check trainable weights count consistency. + + This will raise a warning if `trainable_weights` and + `_collected_trainable_weights` are consistent (i.e. have the same + number of parameters). + Inconsistency will typically arise when one modifies `model.trainable` + without calling `model.compile` again. + """ + if not hasattr(self, '_collected_trainable_weights'): + return + + if len(self.trainable_weights) != len(self._collected_trainable_weights): + logging.warning( + 'Discrepancy between trainable weights and collected trainable' + ' weights, did you set `model.trainable` without calling' + ' `model.compile` after ?') + def _make_train_function(self): if not hasattr(self, 'train_function'): raise RuntimeError('You must compile your model before using it.') + self._check_trainable_weights_consistency() if self.train_function is None: inputs = (self._feed_inputs + self._feed_targets + @@ -1258,7 +1286,7 @@ class Model(Network): for i, batch_out in enumerate(batch_outs): unconcatenated_outs[i].append(batch_out) if verbose == 1: - progbar.update(step) + progbar.update(step + 1) if len(unconcatenated_outs) == 1: return np.concatenate(unconcatenated_outs[0], axis=0) return [ @@ -1313,9 +1341,13 @@ class Model(Network): """ num_samples = self._check_num_samples(ins, batch_size, steps, 'steps') outs = [] - if steps is not None: - if verbose == 1: + + if verbose == 1: + if steps is not None: progbar = Progbar(target=steps) + else: + progbar = Progbar(target=num_samples) + if steps is not None: for step in range(steps): batch_outs = f(ins) if isinstance(batch_outs, list): @@ -1329,7 +1361,7 @@ class Model(Network): outs.append(0.) outs[0] += batch_outs if verbose == 1: - progbar.update(step) + progbar.update(step + 1) for i in range(len(outs)): outs[i] /= steps else: @@ -1380,10 +1412,8 @@ class Model(Network): output_shapes = [] for output_shape, loss_fn in zip(self._feed_output_shapes, self._feed_loss_fns): - if loss_fn.__name__ == 'sparse_categorical_crossentropy': + if loss_fn is losses.sparse_categorical_crossentropy: output_shapes.append(output_shape[:-1] + (1,)) - elif getattr(losses, loss_fn.__name__, None) is None: - output_shapes.append(None) else: output_shapes.append(output_shape) x = _standardize_input_data( @@ -1451,58 +1481,76 @@ class Model(Network): """Trains the model for a fixed number of epochs (iterations on a dataset). Arguments: - x: Numpy array of training data, - or list of Numpy arrays if the model has multiple inputs. - If all inputs in the model are named, - you can also pass a dictionary - mapping input names to Numpy arrays. - y: Numpy array of target data, - or list of Numpy arrays if the model has multiple outputs. - If all outputs in the model are named, - you can also pass a dictionary - mapping output names to Numpy arrays. + x: Numpy array of training data (if the model has a single input), + or list of Numpy arrays (if the model has multiple inputs). + If input layers in the model are named, you can also pass a + dictionary mapping input names to Numpy arrays. + `x` can be `None` (default) if feeding from + TensorFlow data tensors. + y: Numpy array of target (label) data + (if the model has a single output), + or list of Numpy arrays (if the model has multiple outputs). + If output layers in the model are named, you can also pass a + dictionary mapping output names to Numpy arrays. + `y` can be `None` (default) if feeding from + TensorFlow data tensors. + Can be `None` (default) if feeding from framework-native tensors. batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, it will default to 32. - epochs: Integer, the number of times to iterate - over the training data arrays. + epochs: Integer. Number of epochs to train the model. + An epoch is an iteration over the entire `x` and `y` + data provided. + Note that in conjunction with `initial_epoch`, + `epochs` is to be understood as "final epoch". + The model is not trained for a number of iterations + given by `epochs`, but merely until the epoch + of index `epochs` is reached. verbose: 0, 1, or 2. Verbosity mode. - 0 = silent, 1 = verbose, 2 = one log line per epoch. - callbacks: List of callbacks to be called during training. + 0 = silent, 1 = progress bar, 2 = one line per epoch. + callbacks: List of `keras.callbacks.Callback` instances. + List of callbacks to apply during training. See [callbacks](/callbacks). - validation_split: Float between 0 and 1: - fraction of the training data to be used as validation data. + validation_split: Float between 0 and 1. + Fraction of the training data to be used as validation data. The model will set apart this fraction of the training data, will not train on it, and will evaluate the loss and any model metrics on this data at the end of each epoch. - validation_data: Data on which to evaluate - the loss and any model metrics - at the end of each epoch. The model will not - be trained on this data. - This could be a tuple (x_val, y_val) - or a tuple (x_val, y_val, val_sample_weights). - shuffle: Boolean, whether to shuffle the training data - before each epoch. Has no effect when `steps_per_epoch` - is not `None`. - class_weight: Optional dictionary mapping - class indices (integers) to - a weight (float) to apply to the model's loss for the samples - from this class during training. - This can be useful to tell the model to "pay more attention" to - samples from an under-represented class. - sample_weight: Optional array of the same length as x, containing - weights to apply to the model's loss for each sample. - In the case of temporal data, you can pass a 2D array - with shape (samples, sequence_length), + The validation data is selected from the last samples + in the `x` and `y` data provided, before shuffling. + validation_data: tuple `(x_val, y_val)` or tuple + `(x_val, y_val, val_sample_weights)` on which to evaluate + the loss and any model metrics at the end of each epoch. + The model will not be trained on this data. + This will override `validation_split`. + shuffle: Boolean (whether to shuffle the training data + before each epoch) or str (for 'batch'). + 'batch' is a special option for dealing with the + limitations of HDF5 data; it shuffles in batch-sized chunks. + Has no effect when `steps_per_epoch` is not `None`. + class_weight: Optional dictionary mapping class indices (integers) + to a weight (float) value, used for weighting the loss function + (during training only). + This can be useful to tell the model to + "pay more attention" to samples from + an under-represented class. + sample_weight: Optional Numpy array of weights for + the training samples, used for weighting the loss function + (during training only). You can either pass a flat (1D) + Numpy array with the same length as the input samples + (1:1 mapping between weights and samples), + or in the case of temporal data, + you can pass a 2D array with shape + `(samples, sequence_length)`, to apply a different weight to every timestep of every sample. In this case you should make sure to specify - sample_weight_mode="temporal" in compile(). + `sample_weight_mode="temporal"` in `compile()`. initial_epoch: Epoch at which to start training - (useful for resuming a previous training run) + (useful for resuming a previous training run). steps_per_epoch: Total number of steps (batches of samples) before declaring one epoch finished and starting the - next epoch. When training with Input Tensors such as + next epoch. When training with input tensors such as TensorFlow data tensors, the default `None` is equal to the number of unique samples in your dataset divided by the batch size, or 1 if that cannot be determined. @@ -1511,8 +1559,10 @@ class Model(Network): to validate before stopping. Returns: - A `History` instance. Its `history` attribute contains - all information collected during training. + A `History` object. Its `History.history` attribute is + a record of training loss values and metrics values + at successive epochs, as well as validation loss values + and validation metrics values (if applicable). Raises: ValueError: In case of mismatch between the provided input data @@ -1621,8 +1671,8 @@ class Model(Network): validation_steps=validation_steps) def evaluate(self, - x, - y, + x=None, + y=None, batch_size=None, verbose=1, sample_weight=None, @@ -1632,23 +1682,40 @@ class Model(Network): Computation is done in batches. Arguments: - x: Numpy array of test data, - or list of Numpy arrays if the model has multiple inputs. - If all inputs in the model are named, - you can also pass a dictionary - mapping input names to Numpy arrays. - y: Numpy array of target data, - or list of Numpy arrays if the model has multiple outputs. - If all outputs in the model are named, - you can also pass a dictionary - mapping output names to Numpy arrays. - batch_size: Integer. If unspecified, it will default to 32. - verbose: Verbosity mode, 0 or 1. - sample_weight: Array of weights to weight the contribution - of different samples to the loss and metrics. - steps: Total number of steps (batches of samples) + x: Numpy array of test data (if the model has a single input), + or list of Numpy arrays (if the model has multiple inputs). + If input layers in the model are named, you can also pass a + dictionary mapping input names to Numpy arrays. + `x` can be `None` (default) if feeding from + framework-native tensors (e.g. TensorFlow data tensors). + y: Numpy array of target (label) data + (if the model has a single output), + or list of Numpy arrays (if the model has multiple outputs). + If output layers in the model are named, you can also pass a + dictionary mapping output names to Numpy arrays. + `y` can be `None` (default) if feeding from + framework-native tensors (e.g. TensorFlow data tensors). + batch_size: Integer or `None`. + Number of samples per evaluation step. + If unspecified, `batch_size` will default to 32. + verbose: 0 or 1. Verbosity mode. + 0 = silent, 1 = progress bar. + sample_weight: Optional Numpy array of weights for + the test samples, used for weighting the loss function. + You can either pass a flat (1D) + Numpy array with the same length as the input samples + (1:1 mapping between weights and samples), + or in the case of temporal data, + you can pass a 2D array with shape + `(samples, sequence_length)`, + to apply a different weight to every timestep of every sample. + In this case you should make sure to specify + `sample_weight_mode="temporal"` in `compile()`. + steps: Integer or `None`. + Total number of steps (batches of samples) before declaring the evaluation round finished. - Ignored with the default value of `None`. + The default `None` is equal to the number of unique samples in + your dataset divided by the batch size. Returns: Scalar test loss (if the model has a single output and no metrics) @@ -1657,7 +1724,7 @@ class Model(Network): the display labels for the scalar outputs. Raises: - ValueError: In case of invalid argument values. + ValueError: In case of invalid arguments. """ # Backwards compatibility. if batch_size is None and steps is None: @@ -1877,8 +1944,7 @@ class Model(Network): Arguments: generator: A generator or an instance of Sequence (keras.utils.Sequence) - object in order to avoid duplicate data - when using multiprocessing. + object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either - a tuple (inputs, targets) - a tuple (inputs, targets, sample_weights). @@ -1889,8 +1955,8 @@ class Model(Network): steps_per_epoch: Total number of steps (batches of samples) to yield from `generator` before declaring one epoch finished and starting the next epoch. It should typically - be equal to the number of unique samples if your dataset - divided by the batch size. + be equal to the number of unique samples of your dataset + divided by the batch size. Not used if using `Sequence`. epochs: Integer, total number of iterations on the data. verbose: Verbosity mode, 0, 1, or 2. callbacks: List of callbacks to be called during training. @@ -1905,7 +1971,7 @@ class Model(Network): for the class. max_queue_size: Maximum size for the generator queue workers: Maximum number of processes to spin up - when using process based threading + when using process-based threading. use_multiprocessing: If True, use process based threading. Note that because this implementation relies on multiprocessing, @@ -1914,8 +1980,8 @@ class Model(Network): as they can't be passed easily to children processes. shuffle: Whether to shuffle the data at the beginning of each - epoch. Only used with instances of `Sequence` ( - keras.utils.Sequence). + epoch. Only used with instances of `Sequence` + (`keras.utils.Sequence`). initial_epoch: Epoch at which to start training (useful for resuming a previous training run) **kwargs: support for legacy arguments. @@ -1944,7 +2010,7 @@ class Model(Network): ValueError: In case the generator yields data in an invalid format. """ - # Legacy support + # Legacy support if 'max_q_size' in kwargs: max_queue_size = kwargs.pop('max_q_size') logging.warning('The argument `max_q_size` has been renamed ' @@ -2025,6 +2091,8 @@ class Model(Network): ' and multiple workers may duplicate your data.' ' Please consider using the`keras.utils.Sequence' ' class.')) + if is_sequence: + steps_per_epoch = len(generator) enqueuer = None try: @@ -2142,13 +2210,14 @@ class Model(Network): generator: Generator yielding tuples (inputs, targets) or (inputs, targets, sample_weights) or an instance of Sequence (keras.utils.Sequence) - object in order to avoid duplicate data - when using multiprocessing. + object in order to avoid duplicate data + when using multiprocessing. steps: Total number of steps (batches of samples) to yield from `generator` before stopping. + Not used if using `Sequence`. max_queue_size: maximum size for the generator queue workers: maximum number of processes to spin up - when using process based threading + when using process-based threading. use_multiprocessing: if True, use process based threading. Note that because this implementation relies on multiprocessing, @@ -2194,6 +2263,8 @@ class Model(Network): ' and multiple workers may duplicate your data.' ' Please consider using the`keras.utils.Sequence' ' class.')) + if is_sequence: + steps = len(generator) enqueuer = None try: @@ -2273,8 +2344,9 @@ class Model(Network): steps: Total number of steps (batches of samples) to yield from `generator` before stopping. max_queue_size: Maximum size for the generator queue. + Not used if using `Sequence`. workers: Maximum number of processes to spin up - when using process based threading + when using process-based threading. use_multiprocessing: If `True`, use process based threading. Note that because this implementation relies on multiprocessing, @@ -2315,6 +2387,8 @@ class Model(Network): ' and multiple workers may duplicate your data.' ' Please consider using the`keras.utils.Sequence' ' class.')) + if is_sequence: + steps = len(generator) enqueuer = None try: diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py index bc9ad6693e5..e2a06e8e778 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py @@ -640,6 +640,19 @@ class LossMaskingTest(test.TestCase): class TestDynamicTrainability(test.TestCase): + def test_trainable_warning(self): + with self.test_session(): + x = np.random.random((5, 3)) + y = np.random.random((5, 2)) + + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_dim=3)) + model.trainable = False + model.compile('rmsprop', 'mse') + model.trainable = True + model.train_on_batch(x, y) + self.assertRaises(Warning) + def test_trainable_argument(self): with self.test_session(): x = np.random.random((5, 3)) diff --git a/tensorflow/python/keras/_impl/keras/integration_test.py b/tensorflow/python/keras/_impl/keras/integration_test.py index 871a8c73298..15c3d14727a 100644 --- a/tensorflow/python/keras/_impl/keras/integration_test.py +++ b/tensorflow/python/keras/_impl/keras/integration_test.py @@ -22,8 +22,8 @@ import numpy as np from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils -from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.layers import core as tf_core_layers +from tensorflow.python.layers import network as tf_network_layers from tensorflow.python.ops import nn from tensorflow.python.platform import test @@ -275,7 +275,7 @@ class KerasIntegrationTest(test.TestCase): y_train = keras.utils.to_categorical(y_train) y_test = keras.utils.to_categorical(y_test) - inputs = tf_base_layers.Input(shape=(10,)) + inputs = tf_network_layers.Input(shape=(10,)) x = tf_core_layers.Dense(32, activation=nn.relu)(inputs) outputs = tf_core_layers.Dense(2, activation=nn.softmax)(x) model = keras.models.Model(inputs, outputs) diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py index ce96bc66f7c..1cbae912631 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py @@ -793,6 +793,7 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer): strides=(1, 1), padding='valid', data_format=None, + dilation_rate=1, depth_multiplier=1, activation=None, use_bias=True, @@ -815,6 +816,7 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer): strides=strides, padding=padding, data_format=data_format, + dilation_rate=dilation_rate, activation=activations.get(activation), use_bias=use_bias, depthwise_initializer=initializers.get(depthwise_initializer), @@ -831,30 +833,42 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer): def get_config(self): config = { - 'filters': self.filters, - 'kernel_size': self.kernel_size, - 'strides': self.strides, - 'padding': self.padding, - 'data_format': self.data_format, - 'activation': activations.serialize(self.activation), - 'use_bias': self.use_bias, - 'depthwise_initializer': initializers.serialize( - self.depthwise_initializer), - 'pointwise_initializer': initializers.serialize( - self.pointwise_initializer), - 'bias_initializer': initializers.serialize(self.bias_initializer), - 'depthwise_regularizer': regularizers.serialize( - self.depthwise_regularizer), - 'pointwise_regularizer': regularizers.serialize( - self.pointwise_regularizer), - 'bias_regularizer': regularizers.serialize(self.bias_regularizer), + 'filters': + self.filters, + 'kernel_size': + self.kernel_size, + 'strides': + self.strides, + 'padding': + self.padding, + 'data_format': + self.data_format, + 'dilation_rate': + self.dilation_rate, + 'activation': + activations.serialize(self.activation), + 'use_bias': + self.use_bias, + 'depthwise_initializer': + initializers.serialize(self.depthwise_initializer), + 'pointwise_initializer': + initializers.serialize(self.pointwise_initializer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'depthwise_regularizer': + regularizers.serialize(self.depthwise_regularizer), + 'pointwise_regularizer': + regularizers.serialize(self.pointwise_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), - 'depthwise_constraint': constraints.serialize( - self.depthwise_constraint), - 'pointwise_constraint': constraints.serialize( - self.pointwise_constraint), - 'bias_constraint': constraints.serialize(self.bias_constraint) + 'depthwise_constraint': + constraints.serialize(self.depthwise_constraint), + 'pointwise_constraint': + constraints.serialize(self.pointwise_constraint), + 'bias_constraint': + constraints.serialize(self.bias_constraint) } base_config = super(SeparableConv2D, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py index 2335bd4df02..c88122ce188 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py @@ -536,7 +536,7 @@ class ConvLSTM2D(ConvRecurrent2D): conv_out = K.bias_add(conv_out, b, data_format=self.data_format) return conv_out - def reccurent_conv(self, x, w): + def recurrent_conv(self, x, w): conv_out = K.conv2d( x, w, strides=(1, 1), padding='same', data_format=self.data_format) return conv_out @@ -556,10 +556,10 @@ class ConvLSTM2D(ConvRecurrent2D): inputs * dp_mask[2], self.kernel_c, self.bias_c, padding=self.padding) x_o = self.input_conv( inputs * dp_mask[3], self.kernel_o, self.bias_o, padding=self.padding) - h_i = self.reccurent_conv(h_tm1 * rec_dp_mask[0], self.recurrent_kernel_i) - h_f = self.reccurent_conv(h_tm1 * rec_dp_mask[1], self.recurrent_kernel_f) - h_c = self.reccurent_conv(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c) - h_o = self.reccurent_conv(h_tm1 * rec_dp_mask[3], self.recurrent_kernel_o) + h_i = self.recurrent_conv(h_tm1 * rec_dp_mask[0], self.recurrent_kernel_i) + h_f = self.recurrent_conv(h_tm1 * rec_dp_mask[1], self.recurrent_kernel_f) + h_c = self.recurrent_conv(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c) + h_o = self.recurrent_conv(h_tm1 * rec_dp_mask[3], self.recurrent_kernel_o) i = self.recurrent_activation(x_i + h_i) f = self.recurrent_activation(x_f + h_f) diff --git a/tensorflow/python/keras/_impl/keras/layers/core.py b/tensorflow/python/keras/_impl/keras/layers/core.py index b2e0e7b8eeb..517129fab05 100644 --- a/tensorflow/python/keras/_impl/keras/layers/core.py +++ b/tensorflow/python/keras/_impl/keras/layers/core.py @@ -52,7 +52,7 @@ class Masking(Layer): Example: Consider a Numpy data array `x` of shape `(samples, timesteps, features)`, - to be fed to a LSTM layer. + to be fed to an LSTM layer. You want to mask timestep #3 and #5 because you lack data for these timesteps. You can: @@ -121,7 +121,11 @@ class Dropout(tf_core_layers.Dropout, Layer): return output def get_config(self): - config = {'rate': self.rate} + config = { + 'rate': self.rate, + 'noise_shape': self.noise_shape, + 'seed': self.seed + } base_config = super(Dropout, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -383,20 +387,18 @@ class Reshape(Layer): def _compute_output_shape(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape).as_list() - output_shape = [input_shape[0]] - output_shape += self._fix_unknown_dimension(input_shape[1:], - self.target_shape) + if None in input_shape[1:]: + output_shape = [input_shape[0]] + # input shape (partially) unknown? replace -1's with None's + output_shape += tuple(s if s != -1 else None for s in self.target_shape) + else: + output_shape = [input_shape[0]] + output_shape += self._fix_unknown_dimension(input_shape[1:], + self.target_shape) return tensor_shape.TensorShape(output_shape) def call(self, inputs): - # In case the target shape is not fully defined, - # we need access to the shape of x. - target_shape = self.target_shape - if -1 in target_shape: - # target shape not fully defined - target_shape = self._compute_output_shape(inputs.get_shape()) - target_shape = target_shape.as_list()[1:] - return K.reshape(inputs, (-1,) + tuple(target_shape)) + return K.reshape(inputs, (K.shape(inputs)[0],) + self.target_shape) def get_config(self): config = {'target_shape': self.target_shape} @@ -595,6 +597,7 @@ class Lambda(Layer): @classmethod def from_config(cls, config, custom_objects=None): + config = config.copy() globs = globals() if custom_objects: globs = dict(list(globs.items()) + list(custom_objects.items())) diff --git a/tensorflow/python/keras/_impl/keras/layers/core_test.py b/tensorflow/python/keras/_impl/keras/layers/core_test.py index 9cdebd375c8..dd768dc268e 100644 --- a/tensorflow/python/keras/_impl/keras/layers/core_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/core_test.py @@ -111,6 +111,12 @@ class CoreLayersTest(test.TestCase): kwargs={'target_shape': (1, -1)}, input_shape=(3, 2, 4)) + with self.test_session(): + testing_utils.layer_test( + keras.layers.Reshape, + kwargs={'target_shape': (-1, 1)}, + input_shape=(None, None, 2)) + def test_permute(self): with self.test_session(): testing_utils.layer_test( diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py index 84b65d87c2f..888be273693 100644 --- a/tensorflow/python/keras/_impl/keras/layers/merge.py +++ b/tensorflow/python/keras/_impl/keras/layers/merge.py @@ -299,11 +299,26 @@ class Maximum(_Merge): return output +class Minimum(_Merge): + """Layer that computes the minimum (element-wise) a list of inputs. + + It takes as input a list of tensors, + all of the same shape, and returns + a single tensor (also of the same shape). + """ + + def _merge_function(self, inputs): + output = inputs[0] + for i in range(1, len(inputs)): + output = K.minimum(output, inputs[i]) + return output + + class Concatenate(_Merge): """Layer that concatenates a list of inputs. It takes as input a list of tensors, - all of the same shape expect for the concatenation axis, + all of the same shape except for the concatenation axis, and returns a single tensor, the concatenation of all inputs. Arguments: @@ -375,9 +390,8 @@ class Concatenate(_Merge): masks = [] for input_i, mask_i in zip(inputs, mask): if mask_i is None: - # Input is unmasked. Append all 1s to masks, - # but cast it to bool first - masks.append(K.cast(K.ones_like(input_i), 'bool')) + # Input is unmasked. Append all 1s to masks + masks.append(K.ones_like(input_i, dtype='bool')) elif K.ndim(mask_i) < K.ndim(input_i): # Mask is smaller than the input, expand it masks.append(K.expand_dims(mask_i)) @@ -584,6 +598,19 @@ def maximum(inputs, **kwargs): return Maximum(**kwargs)(inputs) +def minimum(inputs, **kwargs): + """Functional interface to the `Minimum` layer. + + Arguments: + inputs: A list of input tensors (at least 2). + **kwargs: Standard layer keyword arguments. + + Returns: + A tensor, the element-wise minimum of the inputs. + """ + return Minimum(**kwargs)(inputs) + + def concatenate(inputs, axis=-1, **kwargs): """Functional interface to the `Concatenate` layer. diff --git a/tensorflow/python/keras/_impl/keras/layers/merge_test.py b/tensorflow/python/keras/_impl/keras/layers/merge_test.py index a5746582791..1f34c367e4b 100644 --- a/tensorflow/python/keras/_impl/keras/layers/merge_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/merge_test.py @@ -116,6 +116,20 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4) + def test_merge_minimum(self): + with self.test_session(): + i1 = keras.layers.Input(shape=(4, 5)) + i2 = keras.layers.Input(shape=(4, 5)) + o = keras.layers.minimum([i1, i2]) + self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) + model = keras.models.Model([i1, i2], o) + + x1 = np.random.random((2, 4, 5)) + x2 = np.random.random((2, 4, 5)) + out = model.predict([x1, x2]) + self.assertEqual(out.shape, (2, 4, 5)) + self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4) + def test_merge_concatenate(self): with self.test_session(): i1 = keras.layers.Input(shape=(4, 5)) diff --git a/tensorflow/python/keras/_impl/keras/layers/pooling.py b/tensorflow/python/keras/_impl/keras/layers/pooling.py index e773e396796..afe4ebfdc53 100644 --- a/tensorflow/python/keras/_impl/keras/layers/pooling.py +++ b/tensorflow/python/keras/_impl/keras/layers/pooling.py @@ -367,7 +367,7 @@ class GlobalAveragePooling1D(_GlobalPooling1D): Output shape: 2D tensor with shape: - `(batch_size, channels)` + `(batch_size, features)` """ def call(self, inputs): @@ -382,7 +382,7 @@ class GlobalMaxPooling1D(_GlobalPooling1D): Output shape: 2D tensor with shape: - `(batch_size, channels)` + `(batch_size, features)` """ def call(self, inputs): diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py index 2bc74d5f807..8df1840b4cb 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py @@ -756,6 +756,8 @@ class RNN(Layer): @property def trainable_weights(self): + if not self.trainable: + return [] if isinstance(self.cell, Layer): return self.cell.trainable_weights return [] @@ -763,6 +765,8 @@ class RNN(Layer): @property def non_trainable_weights(self): if isinstance(self.cell, Layer): + if not self.trainable: + return self.cell.weights return self.cell.non_trainable_weights return [] @@ -1048,7 +1052,6 @@ class SimpleRNN(RNN): unroll=unroll, activity_regularizer=regularizers.get(activity_regularizer), **kwargs) - # self.activity_regularizer = regularizers.get(activity_regularizer) def call(self, inputs, mask=None, training=None, initial_state=None): self.cell._generate_dropout_mask(inputs, training=training) @@ -1114,36 +1117,25 @@ class SimpleRNN(RNN): def get_config(self): config = { - 'units': - self.units, - 'activation': - activations.serialize(self.activation), - 'use_bias': - self.use_bias, - 'kernel_initializer': - initializers.serialize(self.kernel_initializer), + 'units': self.units, + 'activation': activations.serialize(self.activation), + 'use_bias': self.use_bias, + 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), - 'bias_initializer': - initializers.serialize(self.bias_initializer), - 'kernel_regularizer': - regularizers.serialize(self.kernel_regularizer), + 'bias_initializer': initializers.serialize(self.bias_initializer), + 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), - 'bias_regularizer': - regularizers.serialize(self.bias_regularizer), + 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), - 'kernel_constraint': - constraints.serialize(self.kernel_constraint), + 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), - 'bias_constraint': - constraints.serialize(self.bias_constraint), - 'dropout': - self.dropout, - 'recurrent_dropout': - self.recurrent_dropout + 'bias_constraint': constraints.serialize(self.bias_constraint), + 'dropout': self.dropout, + 'recurrent_dropout': self.recurrent_dropout } base_config = super(SimpleRNN, self).get_config() del base_config['cell'] @@ -1597,40 +1589,28 @@ class GRU(RNN): def get_config(self): config = { - 'units': - self.units, - 'activation': - activations.serialize(self.activation), + 'units': self.units, + 'activation': activations.serialize(self.activation), 'recurrent_activation': activations.serialize(self.recurrent_activation), - 'use_bias': - self.use_bias, - 'kernel_initializer': - initializers.serialize(self.kernel_initializer), + 'use_bias': self.use_bias, + 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), - 'bias_initializer': - initializers.serialize(self.bias_initializer), - 'kernel_regularizer': - regularizers.serialize(self.kernel_regularizer), + 'bias_initializer': initializers.serialize(self.bias_initializer), + 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), - 'bias_regularizer': - regularizers.serialize(self.bias_regularizer), + 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), - 'kernel_constraint': - constraints.serialize(self.kernel_constraint), + 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), - 'bias_constraint': - constraints.serialize(self.bias_constraint), - 'dropout': - self.dropout, - 'recurrent_dropout': - self.recurrent_dropout, - 'implementation': - self.implementation + 'bias_constraint': constraints.serialize(self.bias_constraint), + 'dropout': self.dropout, + 'recurrent_dropout': self.recurrent_dropout, + 'implementation': self.implementation } base_config = super(GRU, self).get_config() del base_config['cell'] @@ -2125,42 +2105,29 @@ class LSTM(RNN): def get_config(self): config = { - 'units': - self.units, - 'activation': - activations.serialize(self.activation), + 'units': self.units, + 'activation': activations.serialize(self.activation), 'recurrent_activation': activations.serialize(self.recurrent_activation), - 'use_bias': - self.use_bias, - 'kernel_initializer': - initializers.serialize(self.kernel_initializer), + 'use_bias': self.use_bias, + 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), - 'bias_initializer': - initializers.serialize(self.bias_initializer), - 'unit_forget_bias': - self.unit_forget_bias, - 'kernel_regularizer': - regularizers.serialize(self.kernel_regularizer), + 'bias_initializer': initializers.serialize(self.bias_initializer), + 'unit_forget_bias': self.unit_forget_bias, + 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), - 'bias_regularizer': - regularizers.serialize(self.bias_regularizer), + 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), - 'kernel_constraint': - constraints.serialize(self.kernel_constraint), + 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), - 'bias_constraint': - constraints.serialize(self.bias_constraint), - 'dropout': - self.dropout, - 'recurrent_dropout': - self.recurrent_dropout, - 'implementation': - self.implementation + 'bias_constraint': constraints.serialize(self.bias_constraint), + 'dropout': self.dropout, + 'recurrent_dropout': self.recurrent_dropout, + 'implementation': self.implementation } base_config = super(LSTM, self).get_config() del base_config['cell'] diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py index b1f89a30bb3..7dc4c1db9b4 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py @@ -359,19 +359,38 @@ class RNNTest(test.TestCase): layer.build((None, None, 5)) # Test regularization losses - assert len(layer.losses) == 1 + self.assertEqual(len(layer.losses), 1) # Test weights - assert len(layer.trainable_weights) == 6 + self.assertEqual(len(layer.trainable_weights), 6) cells[0].trainable = False - assert len(layer.trainable_weights) == 3 - assert len(layer.non_trainable_weights) == 3 + self.assertEqual(len(layer.trainable_weights), 3) + self.assertEqual(len(layer.non_trainable_weights), 3) # Test `get_losses_for` x = keras.Input((None, 5)) y = keras.backend.sum(x) cells[0].add_loss(y, inputs=x) - assert layer.get_losses_for(x) == [y] + self.assertEqual(layer.get_losses_for(x), [y]) + + def test_rnn_dynamic_trainability(self): + layer_class = keras.layers.SimpleRNN + embedding_dim = 4 + units = 3 + + layer = layer_class(units) + layer.build((None, None, embedding_dim)) + self.assertEqual(len(layer.weights), 3) + self.assertEqual(len(layer.trainable_weights), 3) + self.assertEqual(len(layer.non_trainable_weights), 0) + layer.trainable = False + self.assertEqual(len(layer.weights), 3) + self.assertEqual(len(layer.trainable_weights), 0) + self.assertEqual(len(layer.non_trainable_weights), 3) + layer.trainable = True + self.assertEqual(len(layer.weights), 3) + self.assertEqual(len(layer.trainable_weights), 3) + self.assertEqual(len(layer.non_trainable_weights), 0) if __name__ == '__main__': diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py index a0cca9dc2fc..aefa5a1c020 100644 --- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py +++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py @@ -26,7 +26,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg -from tensorflow.python.layers import base as tf_base_layers +from tensorflow.python.layers import utils as tf_layers_util class Wrapper(Layer): @@ -77,7 +77,7 @@ class Wrapper(Layer): # get the updates from the inner layer. inner_inputs = inputs if inputs is not None: - uid = tf_base_layers._object_list_uid(inputs) + uid = tf_layers_util.object_list_uid(inputs) if uid in self._input_map: inner_inputs = self._input_map[uid] @@ -97,10 +97,6 @@ class Wrapper(Layer): return losses + super(Wrapper, self).get_losses_for(None) return super(Wrapper, self).get_losses_for(inputs) - @property - def constraints(self): - return self.layer.constraints - def get_weights(self): return self.layer.get_weights() @@ -227,7 +223,7 @@ class TimeDistributed(Wrapper): input_length = K.shape(inputs)[1] # Shape: (num_samples * timesteps, ...). And track the # transformation in self._input_map. - input_uid = tf_base_layers._object_list_uid(inputs) + input_uid = tf_layers_util.object_list_uid(inputs) inputs = K.reshape(inputs, (-1,) + input_shape[2:]) self._input_map[input_uid] = inputs # (num_samples * timesteps, ...) @@ -340,7 +336,8 @@ class Bidirectional(Wrapper): output = [y, y_rev] # Properly set learning phase - if 0 < self.layer.dropout + self.layer.recurrent_dropout: + if (getattr(y, '_uses_learning_phase', False) or + getattr(y_rev, '_uses_learning_phase', False)): if self.merge_mode is None: for out in output: out._uses_learning_phase = True diff --git a/tensorflow/python/keras/_impl/keras/losses.py b/tensorflow/python/keras/_impl/keras/losses.py index 7c6b304622a..19212aeee8c 100644 --- a/tensorflow/python/keras/_impl/keras/losses.py +++ b/tensorflow/python/keras/_impl/keras/losses.py @@ -22,6 +22,7 @@ import six from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object def mean_squared_error(y_true, y_pred): @@ -91,7 +92,7 @@ def poisson(y_true, y_pred): def cosine_proximity(y_true, y_pred): y_true = K.l2_normalize(y_true, axis=-1) y_pred = K.l2_normalize(y_pred, axis=-1) - return -K.mean(y_true * y_pred, axis=-1) + return -K.sum(y_true * y_pred, axis=-1) # Aliases. @@ -105,7 +106,7 @@ cosine = cosine_proximity def serialize(loss): - return loss.__name__ + return serialize_keras_object(loss) def deserialize(name, custom_objects=None): @@ -122,6 +123,8 @@ def get(identifier): if isinstance(identifier, six.string_types): identifier = str(identifier) return deserialize(identifier) + if isinstance(identifier, dict): + return deserialize(identifier) elif callable(identifier): return identifier else: diff --git a/tensorflow/python/keras/_impl/keras/losses_test.py b/tensorflow/python/keras/_impl/keras/losses_test.py index b295356ec19..1884c0fdca7 100644 --- a/tensorflow/python/keras/_impl/keras/losses_test.py +++ b/tensorflow/python/keras/_impl/keras/losses_test.py @@ -18,11 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import shutil + import numpy as np from tensorflow.python.keras._impl import keras from tensorflow.python.platform import test +try: + import h5py # pylint:disable=g-import-not-at-top +except ImportError: + h5py = None ALL_LOSSES = [keras.losses.mean_squared_error, keras.losses.mean_absolute_error, @@ -39,6 +46,20 @@ ALL_LOSSES = [keras.losses.mean_squared_error, keras.losses.categorical_hinge] +class _MSEMAELoss(object): + """Loss function with internal state, for testing serialization code.""" + + def __init__(self, mse_fraction): + self.mse_fraction = mse_fraction + + def __call__(self, y_true, y_pred): + return (self.mse_fraction * keras.losses.mse(y_true, y_pred) + + (1 - self.mse_fraction) * keras.losses.mae(y_true, y_pred)) + + def get_config(self): + return {'mse_fraction': self.mse_fraction} + + class KerasLossesTest(test.TestCase): def test_objective_shapes_3d(self): @@ -83,6 +104,39 @@ class KerasLossesTest(test.TestCase): loss = keras.backend.eval(keras.losses.categorical_hinge(y_true, y_pred)) self.assertAllClose(expected_loss, np.mean(loss)) + def test_serializing_loss_class(self): + orig_loss_class = _MSEMAELoss(0.3) + with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): + serialized = keras.losses.serialize(orig_loss_class) + + with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): + deserialized = keras.losses.deserialize(serialized) + assert isinstance(deserialized, _MSEMAELoss) + assert deserialized.mse_fraction == 0.3 + + def test_serializing_model_with_loss_class(self): + tmpdir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, tmpdir) + model_filename = os.path.join(tmpdir, 'custom_loss.h5') + + with self.test_session(): + with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): + loss = _MSEMAELoss(0.3) + inputs = keras.layers.Input((2,)) + outputs = keras.layers.Dense(1, name='model_output')(inputs) + model = keras.models.Model(inputs, outputs) + model.compile(optimizer='sgd', loss={'model_output': loss}) + model.fit(np.random.rand(256, 2), np.random.rand(256, 1)) + + if h5py is None: + return + + model.save(model_filename) + + with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): + loaded_model = keras.models.load_model(model_filename) + loaded_model.predict(np.random.rand(128, 2)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/models.py b/tensorflow/python/keras/_impl/keras/models.py index 06941e4bac0..ba202827ce3 100644 --- a/tensorflow/python/keras/_impl/keras/models.py +++ b/tensorflow/python/keras/_impl/keras/models.py @@ -31,6 +31,7 @@ from tensorflow.python.keras._impl.keras import layers as layer_module from tensorflow.python.keras._impl.keras import optimizers from tensorflow.python.keras._impl.keras.engine import topology from tensorflow.python.keras._impl.keras.engine.topology import Input +from tensorflow.python.keras._impl.keras.engine.topology import InputLayer from tensorflow.python.keras._impl.keras.engine.topology import Layer from tensorflow.python.keras._impl.keras.engine.topology import TFBaseLayer from tensorflow.python.keras._impl.keras.engine.training import Model @@ -456,38 +457,48 @@ class Sequential(Model): 'an instance of class Layer. ' 'Found: ' + str(layer)) if not self.outputs: - # first layer in model: check that it is an input layer - if not layer._inbound_nodes: - # create an input layer - if not hasattr(layer, '_batch_input_shape'): - raise ValueError('The first layer in a ' - 'Sequential model must ' - 'get an `input_shape` or ' - '`batch_input_shape` argument.') + # First layer in model: check that it is an input layer. + if not isinstance(layer, InputLayer): + # Create an input layer. + # First, we need to infer its expected input shape and dtype. + if isinstance(layer, (Model, Sequential)): + # We were passed a model as first layer. + # This requires a specific way to figure out the + # input shape and dtype. + if not layer.layers: + raise ValueError('Cannot add an empty model ' + 'to a `Sequential` model.') + # In case of nested models: recover the first layer + # of the deepest model to infer input shape and dtype. + first_layer = layer.layers[0] + while isinstance(first_layer, (Model, Sequential)): + first_layer = first_layer.layers[0] + batch_shape = first_layer._batch_input_shape + dtype = first_layer.dtype + else: + # We were passed a regular layer, and it should + # know about its input shape. Otherwise, that's an error. + if not hasattr(layer, '_batch_input_shape'): + raise ValueError('The first layer in a ' + 'Sequential model must ' + 'get an `input_shape` argument.') + batch_shape = layer._batch_input_shape + dtype = layer.dtype # Instantiate the input layer. x = Input( - batch_shape=layer._batch_input_shape, - dtype=layer.dtype, - name=layer.name + '_input') + batch_shape=batch_shape, dtype=dtype, name=layer.name + '_input') # This will build the current layer # and create the node connecting the current layer # to the input layer we just created. layer(x) - if len(layer._inbound_nodes) != 1: - raise ValueError('A layer added to a Sequential model must ' - 'not already be connected somewhere else. ' - 'Model received layer ' + layer.name + ' which has ' + - str(len(layer._inbound_nodes)) + - ' pre-existing inbound connections.') - - if len(layer._inbound_nodes[0].output_tensors) != 1: + if len(layer.inbound_nodes[-1].output_tensors) != 1: raise ValueError('All layers in a Sequential model ' 'should have a single output tensor. ' 'For multi-output layers, ' 'use the functional API.') - self.outputs = [layer._inbound_nodes[0].output_tensors[0]] + self.outputs = [layer.inbound_nodes[-1].output_tensors[0]] self.inputs = topology.get_source_inputs(self.outputs[0]) # We create an input node, which we will keep updated @@ -716,24 +727,42 @@ class Sequential(Model): metrics=None, sample_weight_mode=None, weighted_metrics=None, + target_tensors=None, **kwargs): - """Configures the learning process. + """Configures the model for training. Arguments: - optimizer: str (name of optimizer) or optimizer object. + optimizer: String (name of optimizer) or optimizer object. See [optimizers](/optimizers). - loss: str (name of objective function) or objective function. + loss: String (name of objective function) or objective function. See [losses](/losses). - metrics: list of metrics to be evaluated by the model + If the model has multiple outputs, you can use a different loss + on each output by passing a dictionary or a list of losses. + The loss value that will be minimized by the model + will then be the sum of all individual losses. + metrics: List of metrics to be evaluated by the model during training and testing. Typically you will use `metrics=['accuracy']`. - See [metrics](/metrics). - sample_weight_mode: if you need to do timestep-wise - sample weighting (2D weights), set this to "temporal". - "None" defaults to sample-wise weights (1D). + To specify different metrics for different outputs of a + multi-output model, you could also pass a dictionary, + such as `metrics={'output_a': 'accuracy'}`. + sample_weight_mode: If you need to do timestep-wise + sample weighting (2D weights), set this to `"temporal"`. + `None` defaults to sample-wise weights (1D). + If the model has multiple outputs, you can use a different + `sample_weight_mode` on each output by passing a + dictionary or a list of modes. weighted_metrics: list of metrics to be evaluated and weighted by `sample_weight` or `class_weight` during training and testing. - **kwargs: These are passed into `tf.Session.run`. + target_tensors: By default, Keras will create a placeholder for the + model's target, which will be fed with the target data during + training. If instead you would like to use your own + target tensor (in turn, Keras will not expect external + Numpy data for these targets at training time), you + can specify them via the `target_tensors` argument. + It should be a single tensor + (for a single-output `Sequential` model). + **kwargs: These arguments are passed into `tf.Session.run`. Example: ```python @@ -754,24 +783,25 @@ class Sequential(Model): metrics=metrics, sample_weight_mode=sample_weight_mode, weighted_metrics=weighted_metrics, + target_tensors=target_tensors, **kwargs) self.optimizer = self.model.optimizer self.loss = self.model.loss - self.total_loss = self.model.total_loss - self.loss_weights = self.model.loss_weights self.metrics = self.model.metrics + self.loss_weights = self.model.loss_weights + self.sample_weight_mode = self.model.sample_weight_mode self.weighted_metrics = self.model.weighted_metrics + self.targets = self.model.targets self.metrics_tensors = self.model.metrics_tensors self.metrics_names = self.model.metrics_names - self.sample_weight_mode = self.model.sample_weight_mode self.sample_weights = self.model.sample_weights - self.targets = self.model.targets + self.total_loss = self.model.total_loss def fit(self, - x, - y, - batch_size=32, - epochs=10, + x=None, + y=None, + batch_size=None, + epochs=1, verbose=1, callbacks=None, validation_split=0., @@ -779,43 +809,86 @@ class Sequential(Model): shuffle=True, class_weight=None, sample_weight=None, - initial_epoch=0): + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + **kwargs): """Trains the model for a fixed number of epochs. Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - y: labels, as a Numpy array. - batch_size: integer. Number of samples per gradient update. - epochs: integer, the number of epochs to train the model. - verbose: 0 for no logging to stdout, - 1 for progress bar logging, 2 for one log line per epoch. - callbacks: list of `keras.callbacks.Callback` instances. + x: Numpy array of training data. + If the input layer in the model is named, you can also pass a + dictionary mapping the input name to a Numpy array. + `x` can be `None` (default) if feeding from + TensorFlow data tensors. + y: Numpy array of target (label) data. + If the output layer in the model is named, you can also pass a + dictionary mapping the output name to a Numpy array. + `y` can be `None` (default) if feeding from + TensorFlow data tensors. + batch_size: Integer or `None`. + Number of samples per gradient update. + If unspecified, it will default to 32. + epochs: Integer. Number of epochs to train the model. + An epoch is an iteration over the entire `x` and `y` + data provided. + Note that in conjunction with `initial_epoch`, + `epochs` is to be understood as "final epoch". + The model is not trained for a number of iterations + given by `epochs`, but merely until the epoch + of index `epochs` is reached. + verbose: 0, 1, or 2. Verbosity mode. + 0 = silent, 1 = progress bar, 2 = one line per epoch. + callbacks: List of `keras.callbacks.Callback` instances. List of callbacks to apply during training. See [callbacks](/callbacks). - validation_split: float (0. < x < 1). - Fraction of the data to use as held-out validation data. - validation_data: tuple (x_val, y_val) or tuple - (x_val, y_val, val_sample_weights) to be used as held-out - validation data. Will override validation_split. - shuffle: boolean or str (for 'batch'). - Whether to shuffle the samples at each epoch. + validation_split: Float between 0 and 1: + Fraction of the training data to be used as validation data. + The model will set apart this fraction of the training data, + will not train on it, and will evaluate + the loss and any model metrics + on this data at the end of each epoch. + The validation data is selected from the last samples + in the `x` and `y` data provided, before shuffling. + validation_data: tuple `(x_val, y_val)` or tuple + `(x_val, y_val, val_sample_weights)` on which to evaluate + the loss and any model metrics at the end of each epoch. + The model will not be trained on this data. + This will override `validation_split`. + shuffle: Boolean (whether to shuffle the training data + before each epoch) or str (for 'batch'). 'batch' is a special option for dealing with the limitations of HDF5 data; it shuffles in batch-sized chunks. - class_weight: dictionary mapping classes to a weight value, - used for scaling the loss function (during training only). - sample_weight: Numpy array of weights for - the training samples, used for scaling the loss function + Has no effect when `steps_per_epoch` is not `None`. + class_weight: Optional dictionary mapping class indices (integers) + to a weight (float) value, used for weighting the loss function + (during training only). + This can be useful to tell the model to + "pay more attention" to samples from + an under-represented class. + sample_weight: Optional Numpy array of weights for + the training samples, used for weighting the loss function (during training only). You can either pass a flat (1D) Numpy array with the same length as the input samples (1:1 mapping between weights and samples), or in the case of temporal data, - you can pass a 2D array with shape (samples, sequence_length), + you can pass a 2D array with shape + `(samples, sequence_length)`, to apply a different weight to every timestep of every sample. In this case you should make sure to specify - sample_weight_mode="temporal" in compile(). - initial_epoch: epoch at which to start training - (useful for resuming a previous training run) + `sample_weight_mode="temporal"` in `compile()`. + initial_epoch: Epoch at which to start training + (useful for resuming a previous training run). + steps_per_epoch: Total number of steps (batches of samples) + before declaring one epoch finished and starting the + next epoch. When training with input tensors such as + TensorFlow data tensors, the default `None` is equal to + the number of unique samples in your dataset divided by + the batch size, or 1 if that cannot be determined. + validation_steps: Only relevant if `steps_per_epoch` + is specified. Total number of steps (batches of samples) + to validate before stopping. + **kwargs: Used for backwards compatibility support. Returns: A `History` object. Its `History.history` attribute is @@ -824,10 +897,12 @@ class Sequential(Model): and validation metrics values (if applicable). Raises: - RuntimeError: if the model was never compiled. + RuntimeError: If the model was never compiled. + ValueError: In case of mismatch between the provided input data + and what the model expects. """ if not self.built: - raise RuntimeError('The model needs to be compiled ' 'before being used.') + raise RuntimeError('The model needs to be compiled before being used.') return self.model.fit( x, y, @@ -840,7 +915,9 @@ class Sequential(Model): shuffle=shuffle, class_weight=class_weight, sample_weight=sample_weight, - initial_epoch=initial_epoch) + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps) def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None): """Computes the loss on some input data, batch by batch. @@ -863,7 +940,7 @@ class Sequential(Model): RuntimeError: if the model was never compiled. """ if not self.built: - raise RuntimeError('The model needs to be compiled ' 'before being used.') + raise RuntimeError('The model needs to be compiled before being used.') return self.model.evaluate( x, y, @@ -923,7 +1000,7 @@ class Sequential(Model): RuntimeError: if the model was never compiled. """ if not self.built: - raise RuntimeError('The model needs to be compiled ' 'before being used.') + raise RuntimeError('The model needs to be compiled before being used.') return self.model.train_on_batch( x, y, sample_weight=sample_weight, class_weight=class_weight) @@ -946,10 +1023,10 @@ class Sequential(Model): RuntimeError: if the model was never compiled. """ if not self.built: - raise RuntimeError('The model needs to be compiled ' 'before being used.') + raise RuntimeError('The model needs to be compiled before being used.') return self.model.test_on_batch(x, y, sample_weight=sample_weight) - def predict_proba(self, x, batch_size=32, verbose=1): + def predict_proba(self, x, batch_size=32, verbose=0): """Generates class probability predictions for the input samples. The input samples are processed batch by batch. @@ -971,7 +1048,7 @@ class Sequential(Model): '(like softmax or sigmoid would).') return preds - def predict_classes(self, x, batch_size=32, verbose=1): + def predict_classes(self, x, batch_size=32, verbose=0): """Generate class predictions for the input samples. The input samples are processed batch by batch. @@ -1003,6 +1080,7 @@ class Sequential(Model): max_queue_size=10, workers=1, use_multiprocessing=False, + shuffle=True, initial_epoch=0, **kwargs): """Fits the model on data generated batch-by-batch by a Python generator. @@ -1026,6 +1104,10 @@ class Sequential(Model): be equal to the number of unique samples of your dataset divided by the batch size. epochs: Integer, total number of iterations on the data. + Note that in conjunction with initial_epoch, the parameter + epochs is to be understood as "final epoch". The model is + not trained for n steps given by epochs, but until the + epoch epochs is reached. verbose: Verbosity mode, 0, 1, or 2. callbacks: List of callbacks to be called during training. validation_data: This can be either @@ -1049,6 +1131,9 @@ class Sequential(Model): non picklable arguments to the generator as they can't be passed easily to children processes. + shuffle: Whether to shuffle the order of the batches at + the beginning of each epoch. Only used with instances + of `Sequence` (keras.utils.Sequence). initial_epoch: Epoch at which to start training (useful for resuming a previous training run) **kwargs: support for legacy arguments. @@ -1092,7 +1177,7 @@ class Sequential(Model): raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) if not self.built: - raise RuntimeError('The model needs to be compiled ' 'before being used.') + raise RuntimeError('The model needs to be compiled before being used.') return self.model.fit_generator( generator, steps_per_epoch, @@ -1105,6 +1190,7 @@ class Sequential(Model): max_queue_size=max_queue_size, workers=workers, use_multiprocessing=use_multiprocessing, + shuffle=shuffle, initial_epoch=initial_epoch) def evaluate_generator(self, @@ -1158,7 +1244,7 @@ class Sequential(Model): raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) if not self.built: - raise RuntimeError('The model needs to be compiled ' 'before being used.') + raise RuntimeError('The model needs to be compiled before being used.') return self.model.evaluate_generator( generator, steps, diff --git a/tensorflow/python/keras/_impl/keras/models_test.py b/tensorflow/python/keras/_impl/keras/models_test.py index fd6b20e0edc..86acac4604a 100644 --- a/tensorflow/python/keras/_impl/keras/models_test.py +++ b/tensorflow/python/keras/_impl/keras/models_test.py @@ -315,6 +315,24 @@ class TestSequential(test.TestCase): with self.assertRaises(TypeError): model.build() + def test_nested_sequential_trainability(self): + input_dim = 20 + num_units = 10 + num_classes = 2 + + inner_model = keras.models.Sequential() + inner_model.add(keras.layers.Dense(num_units, input_shape=(input_dim,))) + + model = keras.models.Sequential() + model.add(inner_model) + model.add(keras.layers.Dense(num_classes)) + + self.assertEqual(len(model.trainable_weights), 4) + inner_model.trainable = False + self.assertEqual(len(model.trainable_weights), 2) + inner_model.trainable = True + self.assertEqual(len(model.trainable_weights), 4) + class TestModelCloning(test.TestCase): diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image.py b/tensorflow/python/keras/_impl/keras/preprocessing/image.py index 052a8addc4c..12dc718cd79 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/image.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/image.py @@ -31,6 +31,7 @@ import numpy as np from six.moves import range # pylint: disable=redefined-builtin from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence from tensorflow.python.platform import tf_logging as logging @@ -47,6 +48,21 @@ except ImportError: ndi = None # pylint: enable=g-import-not-at-top +if pil_image is not None: + _PIL_INTERPOLATION_METHODS = { + 'nearest': pil_image.NEAREST, + 'bilinear': pil_image.BILINEAR, + 'bicubic': pil_image.BICUBIC, + } + # These methods were only introduced in version 3.4.0 (2016). + if hasattr(pil_image, 'HAMMING'): + _PIL_INTERPOLATION_METHODS['hamming'] = pil_image.HAMMING + if hasattr(pil_image, 'BOX'): + _PIL_INTERPOLATION_METHODS['box'] = pil_image.BOX + # This method is new in version 1.1.3 (2013). + if hasattr(pil_image, 'LANCZOS'): + _PIL_INTERPOLATION_METHODS['lanczos'] = pil_image.LANCZOS + def random_rotation(x, rg, @@ -172,10 +188,8 @@ def random_zoom(x, (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). cval: Value used for points outside the boundaries of the input if `mode='constant'`. - Returns: Zoomed Numpy image tensor. - Raises: ValueError: if `zoom_range` isn't a tuple. """ @@ -344,7 +358,7 @@ def img_to_array(img, data_format=None): return x -def load_img(path, grayscale=False, target_size=None): +def load_img(path, grayscale=False, target_size=None, interpolation='nearest'): """Loads an image into PIL format. Arguments: @@ -352,12 +366,19 @@ def load_img(path, grayscale=False, target_size=None): grayscale: Boolean, whether to load the image as grayscale. target_size: Either `None` (default to original size) or tuple of ints `(img_height, img_width)`. + interpolation: Interpolation method used to resample the image if the + target size is different from that of the loaded image. + Supported methods are "nearest", "bilinear", and "bicubic". + If PIL version 1.1.3 or newer is installed, "lanczos" is also + supported. If PIL version 3.4.0 or newer is installed, "box" and + "hamming" are also supported. By default, "nearest" is used. Returns: A PIL Image instance. Raises: ImportError: if PIL is not available. + ValueError: if interpolation method is not supported. """ if pil_image is None: raise ImportError('Could not import PIL.Image. ' @@ -369,14 +390,21 @@ def load_img(path, grayscale=False, target_size=None): else: if img.mode != 'RGB': img = img.convert('RGB') - if target_size: - hw_tuple = (target_size[1], target_size[0]) - if img.size != hw_tuple: - img = img.resize(hw_tuple) + if target_size is not None: + width_height_tuple = (target_size[1], target_size[0]) + if img.size != width_height_tuple: + if interpolation not in _PIL_INTERPOLATION_METHODS: + raise ValueError( + 'Invalid interpolation method {} specified. Supported ' + 'methods are {}'.format( + interpolation, + ', '.join(_PIL_INTERPOLATION_METHODS.keys()))) + resample = _PIL_INTERPOLATION_METHODS[interpolation] + img = img.resize(width_height_tuple, resample) return img -def list_pictures(directory, ext='jpg|jpeg|bmp|png'): +def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'): return [ os.path.join(root, f) for root, _, files in os.walk(directory) for f in files @@ -401,7 +429,7 @@ class ImageDataGenerator(object): zoom_range: amount of zoom. if scalar z, zoom will be randomly picked in the range [1-z, 1+z]. A sequence of two can be passed instead to select this range. - channel_shift_range: shift range for each channels. + channel_shift_range: shift range for each channel. fill_mode: points outside the boundaries are filled according to the given mode ('constant', 'nearest', 'reflect' or 'wrap'). Default is 'nearest'. @@ -558,12 +586,10 @@ class ImageDataGenerator(object): x = self.preprocessing_function(x) if self.rescale: x *= self.rescale - # x is a single image, so it doesn't have image number at index 0 - img_channel_axis = self.channel_axis - 1 if self.samplewise_center: - x -= np.mean(x, axis=img_channel_axis, keepdims=True) + x -= np.mean(x, keepdims=True) if self.samplewise_std_normalization: - x /= (np.std(x, axis=img_channel_axis, keepdims=True) + 1e-7) + x /= np.std(x, keepdims=True) + 1e-7 if self.featurewise_center: if self.mean is not None: @@ -762,49 +788,76 @@ class ImageDataGenerator(object): np.dot(u, np.diag(1. / np.sqrt(s + self.zca_epsilon))), u.T) -class Iterator(object): - """Abstract base class for image data iterators. +class Iterator(Sequence): + """Base class for image data iterators. + + Every `Iterator` must implement the `_get_batches_of_transformed_samples` + method. Arguments: - n: Integer, total number of samples in the dataset to loop over. - batch_size: Integer, size of a batch. - shuffle: Boolean, whether to shuffle the data between epochs. - seed: Random seeding for data shuffling. + n: Integer, total number of samples in the dataset to loop over. + batch_size: Integer, size of a batch. + shuffle: Boolean, whether to shuffle the data between epochs. + seed: Random seeding for data shuffling. """ def __init__(self, n, batch_size, shuffle, seed): self.n = n self.batch_size = batch_size + self.seed = seed self.shuffle = shuffle self.batch_index = 0 self.total_batches_seen = 0 self.lock = threading.Lock() - self.index_generator = self._flow_index(n, batch_size, shuffle, seed) + self.index_array = None + self.index_generator = self._flow_index() + + def _set_index_array(self): + self.index_array = np.arange(self.n) + if self.shuffle: + self.index_array = np.random.permutation(self.n) + + def __getitem__(self, idx): + if idx >= len(self): + raise ValueError('Asked to retrieve element {idx}, ' + 'but the Sequence ' + 'has length {length}'.format(idx=idx, + length=len(self))) + if self.seed is not None: + np.random.seed(self.seed + self.total_batches_seen) + self.total_batches_seen += 1 + if self.index_array is None: + self._set_index_array() + index_array = self.index_array[self.batch_size * idx:self.batch_size * + (idx + 1)] + return self._get_batches_of_transformed_samples(index_array) + + def __len__(self): + length = int(np.ceil(self.n / float(self.batch_size))) + return np.maximum(length, 0) + + def on_epoch_end(self): + self._set_index_array() def reset(self): self.batch_index = 0 - def _flow_index(self, n, batch_size=32, shuffle=False, seed=None): + def _flow_index(self): # Ensure self.batch_index is 0. self.reset() while 1: - if seed is not None: - np.random.seed(seed + self.total_batches_seen) + if self.seed is not None: + np.random.seed(self.seed + self.total_batches_seen) if self.batch_index == 0: - index_array = np.arange(n) - if shuffle: - index_array = np.random.permutation(n) + self._set_index_array() - current_index = (self.batch_index * batch_size) % n - if n > current_index + batch_size: - current_batch_size = batch_size + current_index = (self.batch_index * self.batch_size) % self.n + if self.n > current_index + self.batch_size: self.batch_index += 1 else: - current_batch_size = n - current_index self.batch_index = 0 self.total_batches_seen += 1 - yield (index_array[current_index:current_index + current_batch_size], - current_index, current_batch_size) + yield self.index_array[current_index:current_index + self.batch_size] def __iter__(self): # pylint: disable=non-iterator-returned # Needed if we want to do something like: @@ -814,6 +867,16 @@ class Iterator(object): def __next__(self, *args, **kwargs): return self.next(*args, **kwargs) + def _get_batches_of_transformed_samples(self, index_array): + """Gets a batch of transformed samples. + + Arguments: + index_array: array of sample indices to include in batch. + Returns: + A batch of transformed samples. + """ + raise NotImplementedError + class NumpyArrayIterator(Iterator): """Iterator yielding data from a Numpy array. @@ -883,6 +946,26 @@ class NumpyArrayIterator(Iterator): super(NumpyArrayIterator, self).__init__(x.shape[0], batch_size, shuffle, seed) + def _get_batches_of_transformed_samples(self, index_array): + batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]), + dtype=K.floatx()) + for i, j in enumerate(index_array): + x = self.x[j] + x = self.image_data_generator.random_transform(x.astype(K.floatx())) + x = self.image_data_generator.standardize(x) + batch_x[i] = x + if self.save_to_dir: + for i, j in enumerate(index_array): + img = array_to_img(batch_x[i], self.data_format, scale=True) + fname = '{prefix}_{index}_{hash}.{format}'.format( + prefix=self.save_prefix, index=j, hash=np.random.randint(1e4), + format=self.save_format) + img.save(os.path.join(self.save_to_dir, fname)) + if self.y is None: + return batch_x + batch_y = self.y[index_array] + return batch_x, batch_y + def next(self): """For python 2.x. @@ -892,30 +975,10 @@ class NumpyArrayIterator(Iterator): # Keeps under lock only the mechanism which advances # the indexing of each batch. with self.lock: - index_array, current_index, current_batch_size = next( - self.index_generator) + index_array = next(self.index_generator) # The transformation of images is not under thread lock # so it can be done in parallel - batch_x = np.zeros( - tuple([current_batch_size] + list(self.x.shape)[1:]), dtype=K.floatx()) - for i, j in enumerate(index_array): - x = self.x[j] - x = self.image_data_generator.random_transform(x.astype(K.floatx())) - x = self.image_data_generator.standardize(x) - batch_x[i] = x - if self.save_to_dir: - for i in range(current_batch_size): - img = array_to_img(batch_x[i], self.data_format, scale=True) - fname = '{prefix}_{index}_{hash}.{format}'.format( - prefix=self.save_prefix, - index=current_index + i, - hash=np.random.randint(1e4), - format=self.save_format) - img.save(os.path.join(self.save_to_dir, fname)) - if self.y is None: - return batch_x - batch_y = self.y[index_array] - return batch_x, batch_y + return self._get_batches_of_transformed_samples(index_array) def _count_valid_files_in_directory(directory, white_list_formats, @@ -939,7 +1002,7 @@ def _count_valid_files_in_directory(directory, white_list_formats, samples = 0 for _, _, files in _recursive_list(directory): - for fname in files: + for fname in sorted(files): is_valid = False for extension in white_list_formats: if fname.lower().endswith('.' + extension): @@ -1006,7 +1069,7 @@ class DirectoryIterator(Iterator): to use for random transformations and normalization. target_size: tuple of integers, dimensions to resize input images to. color_mode: One of `"rgb"`, `"grayscale"`. Color mode to read images. - classes: Optional list of strings, names of sudirectories + classes: Optional list of strings, names of subdirectories containing images from each class (e.g. `["dogs", "cats"]`). It will be computed automatically if not set. class_mode: Mode for yielding the targets: @@ -1086,7 +1149,7 @@ class DirectoryIterator(Iterator): for subdir in sorted(os.listdir(directory)): if os.path.isdir(os.path.join(directory, subdir)): classes.append(subdir) - self.num_class = len(classes) + self.num_classes = len(classes) self.class_indices = dict(zip(classes, range(len(classes)))) pool = multiprocessing.pool.ThreadPool() @@ -1099,7 +1162,7 @@ class DirectoryIterator(Iterator): for subdir in classes))) print('Found %d images belonging to %d classes.' % (self.samples, - self.num_class)) + self.num_classes)) # second, build an index of the images in the different class subfolders results = [] @@ -1121,39 +1184,25 @@ class DirectoryIterator(Iterator): super(DirectoryIterator, self).__init__(self.samples, batch_size, shuffle, seed) - def next(self): - """For python 2.x. - - Returns: - The next batch. - """ - with self.lock: - index_array, current_index, current_batch_size = next( - self.index_generator) - # The transformation of images is not under thread lock - # so it can be done in parallel - batch_x = np.zeros( - (current_batch_size,) + self.image_shape, dtype=K.floatx()) + def _get_batches_of_transformed_samples(self, index_array): + batch_x = np.zeros((len(index_array),) + self.image_shape, dtype=K.floatx()) grayscale = self.color_mode == 'grayscale' # build batch of image data for i, j in enumerate(index_array): fname = self.filenames[j] - img = load_img( - os.path.join(self.directory, fname), - grayscale=grayscale, - target_size=self.target_size) + img = load_img(os.path.join(self.directory, fname), + grayscale=grayscale, + target_size=self.target_size) x = img_to_array(img, data_format=self.data_format) x = self.image_data_generator.random_transform(x) x = self.image_data_generator.standardize(x) batch_x[i] = x # optionally save augmented images to disk for debugging purposes if self.save_to_dir: - for i in range(current_batch_size): + for i, j in enumerate(index_array): img = array_to_img(batch_x[i], self.data_format, scale=True) fname = '{prefix}_{index}_{hash}.{format}'.format( - prefix=self.save_prefix, - index=current_index + i, - hash=np.random.randint(1e4), + prefix=self.save_prefix, index=j, hash=np.random.randint(1e7), format=self.save_format) img.save(os.path.join(self.save_to_dir, fname)) # build batch of labels @@ -1164,9 +1213,22 @@ class DirectoryIterator(Iterator): elif self.class_mode == 'binary': batch_y = self.classes[index_array].astype(K.floatx()) elif self.class_mode == 'categorical': - batch_y = np.zeros((len(batch_x), self.num_class), dtype=K.floatx()) + batch_y = np.zeros((len(batch_x), self.num_classes), dtype=K.floatx()) for i, label in enumerate(self.classes[index_array]): batch_y[i, label] = 1. else: return batch_x return batch_x, batch_y + + def next(self): + """For python 2.x. + + Returns: + The next batch. + """ + with self.lock: + index_array = next(self.index_generator) + # The transformation of images is not under thread lock + # so it can be done in parallel + return self._get_batches_of_transformed_samples(index_array) + diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py index 19693410e76..c0790b5a514 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py @@ -192,6 +192,8 @@ class TestImage(test.TestCase): _ = keras.preprocessing.image.load_img(fname) _ = keras.preprocessing.image.load_img(fname, grayscale=True) _ = keras.preprocessing.image.load_img(fname, target_size=(10, 10)) + _ = keras.preprocessing.image.load_img(fname, target_size=(10, 10), + interpolation='bilinear') # create iterator generator = keras.preprocessing.image.ImageDataGenerator() diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py index a5deec87af7..642f4f2face 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py @@ -169,7 +169,7 @@ def skipgrams(sequence, integers (eg. [0, 1, 1 .. ]), if True labels will be categorical eg. [[1,0],[0,1],[0,1] .. ] sampling_table: 1D array of size `vocabulary_size` where the entry i - encodes the probabibily to sample a word of rank i. + encodes the probability to sample a word of rank i. seed: Random seed. Returns: diff --git a/tensorflow/python/keras/_impl/keras/utils/__init__.py b/tensorflow/python/keras/_impl/keras/utils/__init__.py index 78f325cf619..370ae0dd0f0 100644 --- a/tensorflow/python/keras/_impl/keras/utils/__init__.py +++ b/tensorflow/python/keras/_impl/keras/utils/__init__.py @@ -30,6 +30,7 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary from tensorflow.python.keras._impl.keras.utils.np_utils import normalize from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical from tensorflow.python.keras._impl.keras.utils.training_utils import multi_gpu_model diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils.py b/tensorflow/python/keras/_impl/keras/utils/data_utils.py index 0ede7f12f2c..1f2e9ac4407 100644 --- a/tensorflow/python/keras/_impl/keras/utils/data_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/data_utils.py @@ -70,15 +70,15 @@ if sys.version_info[0] == 2: if content_type is not None: total_size = int(content_type.strip()) count = 0 - while 1: + while True: chunk = response.read(chunk_size) count += 1 - if not chunk: - reporthook(count, total_size, total_size) - break - if reporthook: + if reporthook is not None: reporthook(count, chunk_size, total_size) - yield chunk + if chunk: + yield chunk + else: + break response = urlopen(url, data) with open(filename, 'wb') as fd: @@ -262,9 +262,9 @@ def _hash_file(fpath, algorithm='sha256', chunk_size=65535): Example: ```python - >>> from keras.data_utils import _hash_file - >>> _hash_file('/path/to/file.zip') - 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' + >>> from keras.data_utils import _hash_file + >>> _hash_file('/path/to/file.zip') + 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' ``` Arguments: @@ -318,32 +318,35 @@ class Sequence(object): """Base object for fitting to a sequence of data, such as a dataset. Every `Sequence` must implements the `__getitem__` and the `__len__` methods. + If you want to modify your dataset between epochs you may implement + `on_epoch_end`. The method `__getitem__` should return a complete batch. + Notes: + `Sequence` are a safer way to do multiprocessing. This structure guarantees + that the network will only train once on each sample per epoch which is not + the case with generators. Examples: - ```python - from skimage.io import imread - from skimage.transform import resize - import numpy as np - - # Here, `x_set` is list of path to the images - # and `y_set` are the associated classes. - - class CIFAR10Sequence(Sequence): - def __init__(self, x_set, y_set, batch_size): - self.X,self.y = x_set,y_set - self.batch_size = batch_size - - def __len__(self): - return len(self.X) // self.batch_size - - def __getitem__(self,idx): - batch_x = self.X[idx*self.batch_size:(idx+1)*self.batch_size] - batch_y = self.y[idx*self.batch_size:(idx+1)*self.batch_size] - - return np.array([ - resize(imread(file_name), (200,200)) - for file_name in batch_x]), np.array(batch_y) + from skimage.io import imread + from skimage.transform import resize + import numpy as np + import math + # Here, `x_set` is list of path to the images + # and `y_set` are the associated classes. + class CIFAR10Sequence(Sequence): + def __init__(self, x_set, y_set, batch_size): + self.x, self.y = x_set, y_set + self.batch_size = batch_size + def __len__(self): + return math.ceil(len(self.x) / self.batch_size) + def __getitem__(self, idx): + batch_x = self.x[idx * self.batch_size:(idx + 1) * + self.batch_size] + batch_y = self.y[idx * self.batch_size:(idx + 1) * + self.batch_size] + return np.array([ + resize(imread(file_name), (200, 200)) + for file_name in batch_x]), np.array(batch_y) ``` """ @@ -372,20 +375,30 @@ class Sequence(object): def on_epoch_end(self): """Method called at the end of every epoch. """ - raise NotImplementedError + pass -def get_index(ds, i): - """Quick fix for Python2, otherwise, it cannot be pickled. +# Global variables to be shared across processes +_SHARED_SEQUENCES = {} +# We use a Value to provide unique id to different processes. +_SEQUENCE_COUNTER = None + + +def get_index(uid, i): + """Get the value from the Sequence `uid` at index `i`. + + To allow multiple Sequences to be used at the same time, we use `uid` to + get a specific one. A single Sequence would cause the validation to + overwrite the training Sequence. Arguments: - ds: a Holder or Sequence object. + uid: int, Sequence identifier i: index Returns: The value at index `i`. """ - return ds[i] + return _SHARED_SEQUENCES[uid][i] class SequenceEnqueuer(object): @@ -397,13 +410,13 @@ class SequenceEnqueuer(object): Examples: ```python - enqueuer = SequenceEnqueuer(...) - enqueuer.start() - datas = enqueuer.get() - for data in datas: - # Use the inputs; training, evaluating, predicting. - # ... stop sometime. - enqueuer.close() + enqueuer = SequenceEnqueuer(...) + enqueuer.start() + datas = enqueuer.get() + for data in datas: + # Use the inputs; training, evaluating, predicting. + # ... stop sometime. + enqueuer.close() ``` The `enqueuer.get()` should be an infinite stream of datas. @@ -456,17 +469,21 @@ class OrderedEnqueuer(SequenceEnqueuer): Arguments: sequence: A `keras.utils.data_utils.Sequence` object. - use_multiprocessing: use multiprocessing if True, otherwise threading - scheduling: Sequential querying of datas if 'sequential', random - otherwise. - shuffle: Whether to shuffle the data at the beginning of each epoch. + use_multiprocessing: Use multiprocessing if True, otherwise threading + shuffle: Whether to shuffle the data at the beginning of each epoch """ - def __init__(self, - sequence, - use_multiprocessing=False, - shuffle=False): + def __init__(self, sequence, use_multiprocessing=False, shuffle=False): self.sequence = sequence + + # Doing Multiprocessing.Value += x is not process-safe. + global _SEQUENCE_COUNTER + if _SEQUENCE_COUNTER is None: + _SEQUENCE_COUNTER = multiprocessing.Value('i', 0) + + with _SEQUENCE_COUNTER.get_lock(): + self.uid = _SEQUENCE_COUNTER.value + _SEQUENCE_COUNTER.value += 1 self.use_multiprocessing = use_multiprocessing self.shuffle = shuffle self.workers = 0 @@ -490,15 +507,24 @@ class OrderedEnqueuer(SequenceEnqueuer): self.executor = multiprocessing.Pool(workers) else: self.executor = ThreadPool(workers) + self.workers = workers self.queue = queue.Queue(max_queue_size) self.stop_signal = threading.Event() self.run_thread = threading.Thread(target=self._run) self.run_thread.daemon = True self.run_thread.start() + def _wait_queue(self): + """Wait for the queue to be empty.""" + while True: + time.sleep(0.1) + if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set(): + return + def _run(self): - """Submits requests to the executor and queues the `Future` objects.""" + """Function to submit request to the executor & queue `Future` objects.""" sequence = list(range(len(self.sequence))) + self._send_sequence() # Share the initial sequence while True: if self.shuffle: random.shuffle(sequence) @@ -506,9 +532,18 @@ class OrderedEnqueuer(SequenceEnqueuer): if self.stop_signal.is_set(): return self.queue.put( - self.executor.apply_async(get_index, (self.sequence, i)), - block=True) + self.executor.apply_async(get_index, (self.uid, i)), block=True) + + # Done with the current epoch, waiting for the final batches + self._wait_queue() + + if self.stop_signal.is_set(): + # We're done + return + + # Call the internal on epoch end. self.sequence.on_epoch_end() + self._send_sequence() # Update the pool def get(self): """Creates a generator to extract data from the queue. @@ -517,17 +552,29 @@ class OrderedEnqueuer(SequenceEnqueuer): Yields: Tuples (inputs, targets) - or (inputs, targets, sample_weights) + or (inputs, targets, sample_weights) """ try: while self.is_running(): inputs = self.queue.get(block=True).get() + self.queue.task_done() if inputs is not None: yield inputs except Exception as e: self.stop() raise StopIteration(e) + def _send_sequence(self): + """Send current Sequence to all workers.""" + _SHARED_SEQUENCES[ + self.uid] = self.sequence # For new processes that may spawn + + self._close_pool() + if self.use_multiprocessing: + self.executor = multiprocessing.Pool(self.workers) + else: + self.executor = ThreadPool(self.workers) + def stop(self, timeout=None): """Stops running threads and wait for them to exit, if necessary. @@ -541,36 +588,43 @@ class OrderedEnqueuer(SequenceEnqueuer): self.queue.queue.clear() self.queue.unfinished_tasks = 0 self.queue.not_full.notify() + self._close_pool() + self.run_thread.join(timeout) + _SHARED_SEQUENCES[self.uid] = None + + def _close_pool(self): self.executor.close() self.executor.join() - self.run_thread.join(timeout) class GeneratorEnqueuer(SequenceEnqueuer): """Builds a queue out of a data generator. + The provided generator can be finite in which case the class will throw + a `StopIteration` exception. + Used in `fit_generator`, `evaluate_generator`, `predict_generator`. Arguments: - generator: a generator function which endlessly yields data + generator: a generator function which yields data use_multiprocessing: use multiprocessing if True, otherwise threading wait_time: time to sleep in-between calls to `put()` random_seed: Initial seed for workers, - will be incremented by one for each workers. + will be incremented by one for each worker. """ def __init__(self, generator, use_multiprocessing=False, wait_time=0.05, - random_seed=None): + seed=None): self.wait_time = wait_time self._generator = generator self._use_multiprocessing = use_multiprocessing self._threads = [] self._stop_event = None self.queue = None - self.random_seed = random_seed + self.seed = seed def start(self, workers=1, max_queue_size=10): """Kicks off threads which add data from the generator into the queue. @@ -589,6 +643,8 @@ class GeneratorEnqueuer(SequenceEnqueuer): self.queue.put(generator_output) else: time.sleep(self.wait_time) + except StopIteration: + break except Exception: self._stop_event.set() raise @@ -605,11 +661,11 @@ class GeneratorEnqueuer(SequenceEnqueuer): if self._use_multiprocessing: # Reset random seed else all children processes # share the same seed - np.random.seed(self.random_seed) + np.random.seed(self.seed) thread = multiprocessing.Process(target=data_generator_task) thread.daemon = True - if self.random_seed is not None: - self.random_seed += 1 + if self.seed is not None: + self.seed += 1 else: thread = threading.Thread(target=data_generator_task) self._threads.append(thread) @@ -661,4 +717,8 @@ class GeneratorEnqueuer(SequenceEnqueuer): if inputs is not None: yield inputs else: - time.sleep(self.wait_time) + all_finished = all([not thread.is_alive() for thread in self._threads]) + if all_finished and self.queue.empty(): + raise StopIteration() + else: + time.sleep(self.wait_time) diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py b/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py index 45322f1f29c..14b2f084423 100644 --- a/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py +++ b/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py @@ -115,15 +115,19 @@ def threadsafe_generator(f): class TestSequence(keras.utils.data_utils.Sequence): - def __init__(self, shape): + def __init__(self, shape, value=1.): self.shape = shape + self.inner = value def __getitem__(self, item): - return np.ones(self.shape, dtype=np.uint8) * item + return np.ones(self.shape, dtype=np.uint32) * item * self.inner def __len__(self): return 100 + def on_epoch_end(self): + self.inner *= 5.0 + class FaultSequence(keras.utils.data_utils.Sequence): @@ -228,6 +232,64 @@ class TestEnqueuers(test.TestCase): with self.assertRaises(StopIteration): next(gen_output) + def test_on_epoch_end_processes(self): + enqueuer = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3]), use_multiprocessing=True) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + acc = [] + for _ in range(200): + acc.append(next(gen_output)[0, 0, 0, 0]) + # Check that order was keep in GeneratorEnqueuer with processes + self.assertEqual(acc[100:], list([k * 5 for k in range(100)])) + enqueuer.stop() + + def test_context_switch(self): + enqueuer = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3]), use_multiprocessing=True) + enqueuer2 = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3], value=15), use_multiprocessing=True) + enqueuer.start(3, 10) + enqueuer2.start(3, 10) + gen_output = enqueuer.get() + gen_output2 = enqueuer2.get() + acc = [] + for _ in range(100): + acc.append(next(gen_output)[0, 0, 0, 0]) + self.assertEqual(acc[-1], 99) + # One epoch is completed so enqueuer will switch the Sequence + + acc = [] + for _ in range(100): + acc.append(next(gen_output2)[0, 0, 0, 0]) + self.assertEqual(acc[-1], 99 * 15) + # One epoch has been completed so enqueuer2 will switch + + # Be sure that both Sequence were updated + self.assertEqual(next(gen_output)[0, 0, 0, 0], 0) + self.assertEqual(next(gen_output)[0, 0, 0, 0], 5) + self.assertEqual(next(gen_output2)[0, 0, 0, 0], 0) + self.assertEqual(next(gen_output2)[0, 0, 0, 0], 15 * 5) + + # Tear down everything + enqueuer.stop() + enqueuer2.stop() + + def test_on_epoch_end_threads(self): + enqueuer = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3]), use_multiprocessing=False) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + acc = [] + for _ in range(100): + acc.append(next(gen_output)[0, 0, 0, 0]) + acc = [] + for _ in range(100): + acc.append(next(gen_output)[0, 0, 0, 0]) + # Check that order was keep in GeneratorEnqueuer with processes + self.assertEqual(acc, list([k * 5 for k in range(100)])) + enqueuer.stop() + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py index 39a10c8650f..025e5d30a59 100644 --- a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py @@ -43,7 +43,7 @@ class CustomObjectScope(object): Example: - Consider a custom object `MyObject` + Consider a custom object `MyObject` (e.g. a class): ```python with CustomObjectScope({'MyObject':MyObject}): @@ -271,6 +271,9 @@ class Progbar(object): self.total_width = 0 self.seen_so_far = 0 self.verbose = verbose + self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and + sys.stdout.isatty()) or + 'ipykernel' in sys.modules) def update(self, current, values=None, force=False): """Updates the progress bar. @@ -294,18 +297,23 @@ class Progbar(object): self.seen_so_far = current now = time.time() + info = ' - %.0fs' % (now - self.start) if self.verbose == 1: - if not force and (now - self.last_update) < self.interval: + if (not force and (now - self.last_update) < self.interval and + current < self.target): return prev_total_width = self.total_width - sys.stdout.write('\b' * prev_total_width) - sys.stdout.write('\r') + if self._dynamic_display: + sys.stdout.write('\b' * prev_total_width) + sys.stdout.write('\r') + else: + sys.stdout.write('\n') - if self.target is not -1: + if self.target is not None: numdigits = int(np.floor(np.log10(self.target))) + 1 - barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) - bar = barstr % (current, self.target) + barstr = '%%%dd/%d [' % (numdigits, self.target) + bar = barstr % current prog = float(current) / self.target prog_width = int(self.width * prog) if prog_width > 0: @@ -318,17 +326,35 @@ class Progbar(object): bar += ']' sys.stdout.write(bar) self.total_width = len(bar) + else: + bar = '%7d/Unknown' % current + + self.total_width = len(bar) + sys.stdout.write(bar) if current: time_per_unit = (now - self.start) / current else: time_per_unit = 0 - eta = time_per_unit * (self.target - current) - info = '' - if current < self.target and self.target is not -1: - info += ' - ETA: %ds' % eta + if self.target is not None and current < self.target: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, + eta % 60) + elif eta > 60: + eta_format = '%d:%02d' % (eta // 60, eta % 60) + else: + eta_format = '%ds' % eta + + info = ' - ETA: %s' % eta_format else: - info += ' - %ds' % (now - self.start) + if time_per_unit >= 1: + info += ' %.0fs/step' % time_per_unit + elif time_per_unit >= 1e-3: + info += ' %.0fms/step' % (time_per_unit * 1e3) + else: + info += ' %.0fus/step' % (time_per_unit * 1e6) + for k in self.unique_values: info += ' - %s:' % k if isinstance(self.sum_values[k], list): @@ -342,7 +368,9 @@ class Progbar(object): self.total_width += len(info) if prev_total_width > self.total_width: - info += ((prev_total_width - self.total_width) * ' ') + info += (' ' * (prev_total_width - self.total_width)) + if self.target is not None and current >= self.target: + info += '\n' sys.stdout.write(info) sys.stdout.flush() @@ -350,17 +378,20 @@ class Progbar(object): if current >= self.target: sys.stdout.write('\n') - if self.verbose == 2: - if current >= self.target: - info = '%ds' % (now - self.start) + elif self.verbose == 2: + if self.target is None or current >= self.target: for k in self.unique_values: info += ' - %s:' % k - avg = np.mean(self.sum_values[k][0] / max(1, self.sum_values[k][1])) + avg = np.mean( + self.sum_values[k][0] / max(1, self.sum_values[k][1])) if avg > 1e-3: info += ' %.4f' % avg else: info += ' %.4e' % avg - sys.stdout.write(info + '\n') + info += '\n' + + sys.stdout.write(info) + sys.stdout.flush() self.last_update = now diff --git a/tensorflow/python/keras/_impl/keras/utils/io_utils.py b/tensorflow/python/keras/_impl/keras/utils/io_utils.py index 5f2ba99be78..1c8299c27d2 100644 --- a/tensorflow/python/keras/_impl/keras/utils/io_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/io_utils.py @@ -84,7 +84,7 @@ class HDF5Matrix(object): if start is None: start = 0 if stop is None: - stop = self.data.shape[0] + stop = self.shape[0] if stop + self.start <= self.end: idx = slice(start + self.start, stop + self.start) else: diff --git a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py index 86c02643556..053c0600a33 100644 --- a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py @@ -24,6 +24,18 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.conv_utils import convert_kernel +def count_params(weights): + """Count the total number of scalars composing the weights. + + Arguments: + weights: An iterable containing the weights on which to compute params + + Returns: + The total number of scalars composing the weights + """ + return int(np.sum([K.count_params(p) for p in set(weights)])) + + def print_summary(model, line_length=None, positions=None, print_fn=None): """Prints a summary of a model. @@ -46,12 +58,28 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): sequential_like = True else: sequential_like = True - for v in model._nodes_by_depth.values(): # pylint: disable=protected-access + nodes_by_depth = model._nodes_by_depth.values() # pylint: disable=protected-access + nodes = [] + for v in nodes_by_depth: if (len(v) > 1) or (len(v) == 1 and len(v[0].inbound_layers) > 1): # If the model has multiple nodes or if the nodes have # multiple inbound_layers, the model is no longer sequential. sequential_like = False break + nodes += v + if sequential_like: + # search for shared layers + for layer in model.layers: + flag = False + for node in layer.inbound_nodes: + if node in nodes: + if flag: + sequential_like = False + break + else: + flag = True + if not sequential_like: + break if sequential_like: line_length = line_length or 65 @@ -61,7 +89,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): # header names for the different log elements to_display = ['Layer (type)', 'Output Shape', 'Param #'] else: - line_length = line_length or 100 + line_length = line_length or 98 positions = positions or [.33, .55, .67, 1.] if positions[-1] <= 1: positions = [int(line_length * p) for p in positions] @@ -144,8 +172,12 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): else: print_fn('_' * line_length) - trainable_count = int( - np.sum([K.count_params(p) for p in set(model.trainable_weights)])) + model._check_trainable_weights_consistency() # pylint: disable=protected-access + if hasattr(model, '_collected_trainable_weights'): + trainable_count = count_params(model._collected_trainable_weights) # pylint: disable=protected-access + else: + trainable_count = count_params(model.trainable_weights) + non_trainable_count = int( np.sum([K.count_params(p) for p in set(model.non_trainable_weights)])) diff --git a/tensorflow/python/keras/_impl/keras/utils/np_utils.py b/tensorflow/python/keras/_impl/keras/utils/np_utils.py index a23172d342a..896016d4d8b 100644 --- a/tensorflow/python/keras/_impl/keras/utils/np_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/np_utils.py @@ -33,12 +33,18 @@ def to_categorical(y, num_classes=None): Returns: A binary matrix representation of the input. """ - y = np.array(y, dtype='int').ravel() + y = np.array(y, dtype='int') + input_shape = y.shape + if input_shape and input_shape[-1] == 1: + input_shape = tuple(input_shape[:-1]) + y = y.ravel() if not num_classes: num_classes = np.max(y) + 1 n = y.shape[0] categorical = np.zeros((n, num_classes)) categorical[np.arange(n), y] = 1 + output_shape = input_shape + (num_classes,) + categorical = np.reshape(categorical, output_shape) return categorical diff --git a/tensorflow/python/keras/_impl/keras/utils/np_utils_test.py b/tensorflow/python/keras/_impl/keras/utils/np_utils_test.py new file mode 100644 index 00000000000..9680c295cd3 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/utils/np_utils_test.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for np_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.keras._impl import keras +from tensorflow.python.platform import test + + +class TestNPUtils(test.TestCase): + + def test_to_categorical(self): + num_classes = 5 + shapes = [(3,), (4, 3), (5, 4, 3), (3, 1), (3, 2, 1)] + expected_shapes = [(3, num_classes), + (4, 3, num_classes), + (5, 4, 3, num_classes), + (3, num_classes)] + labels = [np.random.randint(0, num_classes, shape) for shape in shapes] + one_hots = [ + keras.utils.to_categorical(label, num_classes) for label in labels] + for label, one_hot, expected_shape in zip(labels, + one_hots, + expected_shapes): + # Check shape + self.assertEqual(one_hot.shape, expected_shape) + # Make sure there is only one 1 in a row + self.assertTrue(np.all(one_hot.sum(axis=-1) == 1)) + # Get original labels back from one hots + self.assertTrue(np.all( + np.argmax(one_hot, -1).reshape(label.shape) == label)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/_impl/keras/utils/training_utils.py b/tensorflow/python/keras/_impl/keras/utils/training_utils.py index b993a16394d..8939c814cf3 100644 --- a/tensorflow/python/keras/_impl/keras/utils/training_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/training_utils.py @@ -77,8 +77,11 @@ def multi_gpu_model(model, gpus): width = 224 num_classes = 1000 - # Instantiate the base model - # (here, we do it on CPU, for better efficiency). + # Instantiate the base model (or "template" model). + # We recommend doing this with under a CPU device scope, + # so that the model's weights are hosted on CPU memory. + # Otherwise they may end up hosted on a GPU, which would + # complicate weight sharing. with tf.device('/cpu:0'): model = Xception(weights=None, input_shape=(height, width, 3), @@ -97,6 +100,9 @@ def multi_gpu_model(model, gpus): # This `fit` call will be distributed on 8 GPUs. # Since the batch size is 256, each GPU will process 32 samples. parallel_model.fit(x, y, epochs=20, batch_size=256) + + # Save model via the template model (which shares the same weights): + model.save('my_model.h5') ``` Raises: diff --git a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py index ac7bd494062..31ef4773ad6 100644 --- a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py +++ b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py @@ -352,5 +352,5 @@ class KerasRegressor(BaseWrapper): kwargs = self.filter_sk_params(Sequential.evaluate, kwargs) loss = self.model.evaluate(x, y, **kwargs) if isinstance(loss, list): - return loss[0] - return loss + return -loss[0] + return -loss diff --git a/tensorflow/python/keras/datasets/__init__.py b/tensorflow/python/keras/datasets/__init__.py index b76f278964b..69e10bd63c7 100644 --- a/tensorflow/python/keras/datasets/__init__.py +++ b/tensorflow/python/keras/datasets/__init__.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.keras.datasets import boston_housing from tensorflow.python.keras.datasets import cifar10 from tensorflow.python.keras.datasets import cifar100 +from tensorflow.python.keras.datasets import fashion_mnist from tensorflow.python.keras.datasets import imdb from tensorflow.python.keras.datasets import mnist from tensorflow.python.keras.datasets import reuters diff --git a/tensorflow/python/keras/datasets/fashion_mnist/__init__.py b/tensorflow/python/keras/datasets/fashion_mnist/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py index 76c790a0a20..d4bc71f1c8e 100644 --- a/tensorflow/python/kernel_tests/lookup_ops_test.py +++ b/tensorflow/python/kernel_tests/lookup_ops_test.py @@ -281,6 +281,37 @@ class IndexTableFromFile(test.TestCase): lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) + def test_string_index_table_from_multicolumn_file(self): + vocabulary_file = self._createVocabFile( + "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1")) + with self.test_session(): + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, + num_oov_buckets=1, + key_column_index=0, + value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + + self.assertRaises(errors_impl.OpError, ids.eval) + lookup_ops.tables_initializer().run() + self.assertAllEqual((1, 2, 3), ids.eval()) + + def test_string_index_table_from_multicolumn_file_custom_delimiter(self): + vocabulary_file = self._createVocabFile( + "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1")) + with self.test_session(): + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, + num_oov_buckets=1, + key_column_index=0, + value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, + delimiter=" ") + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + + self.assertRaises(errors_impl.OpError, ids.eval) + lookup_ops.tables_initializer().run() + self.assertAllEqual((1, 2, 3), ids.eval()) + def test_string_index_table_from_file_tensor_filename(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") with self.test_session(): @@ -566,10 +597,10 @@ class IndexTableFromTensor(test.TestCase): class IndexToStringTableFromFileTest(test.TestCase): - def _createVocabFile(self, basename): + def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): vocabulary_file = os.path.join(self.get_temp_dir(), basename) with open(vocabulary_file, "w") as f: - f.write("\n".join(["brain", "salad", "surgery"]) + "\n") + f.write("\n".join(values) + "\n") return vocabulary_file def test_index_to_string_table(self): @@ -583,6 +614,35 @@ class IndexToStringTableFromFileTest(test.TestCase): self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) + def test_index_to_string_table_from_multicolumn_file(self): + vocabulary_file = self._createVocabFile( + "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1")) + with self.test_session(): + table = lookup_ops.index_to_string_table_from_file( + vocabulary_file=vocabulary_file, + key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, + value_column_index=0) + features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) + self.assertRaises(errors_impl.OpError, features.eval) + lookup_ops.tables_initializer().run() + self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), + features.eval()) + + def test_index_to_string_table_from_multicolumn_file_custom_delimiter(self): + vocabulary_file = self._createVocabFile( + "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1")) + with self.test_session(): + table = lookup_ops.index_to_string_table_from_file( + vocabulary_file=vocabulary_file, + key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, + value_column_index=0, + delimiter=" ") + features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) + self.assertRaises(errors_impl.OpError, features.eval) + lookup_ops.tables_initializer().run() + self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), + features.eval()) + def test_index_to_string_table_with_default_value(self): default_value = b"NONE" vocabulary_file = self._createVocabFile("f2i_vocab2.txt") diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py index 971dc9d5530..3358b78efd2 100644 --- a/tensorflow/python/kernel_tests/metrics_test.py +++ b/tensorflow/python/kernel_tests/metrics_test.py @@ -3857,6 +3857,56 @@ class MeanPerClassAccuracyTest(test.TestCase): self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval()) +class FalseNegativesTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.false_negatives( + labels=(0, 1, 0, 1), + predictions=(0, 0, 1, 1)) + _assert_metric_variables(self, ('false_negatives/count:0',)) + + def testUnweighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + tn, tn_update_op = metrics.false_negatives( + labels=labels, predictions=predictions) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(3., tn_update_op.eval()) + self.assertAllClose(3., tn.eval()) + + def testWeighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + weights = constant_op.constant((1., 1.5, 2., 2.5)) + tn, tn_update_op = metrics.false_negatives( + labels=labels, predictions=predictions, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(5., tn_update_op.eval()) + self.assertAllClose(5., tn.eval()) + + class FalseNegativesAtThresholdsTest(test.TestCase): def setUp(self): @@ -3906,6 +3956,56 @@ class FalseNegativesAtThresholdsTest(test.TestCase): self.assertAllEqual((0.0, 8.0, 11.0), fn.eval()) +class FalsePositivesTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.false_positives( + labels=(0, 1, 0, 1), + predictions=(0, 0, 1, 1)) + _assert_metric_variables(self, ('false_positives/count:0',)) + + def testUnweighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + tn, tn_update_op = metrics.false_positives( + labels=labels, predictions=predictions) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(7., tn_update_op.eval()) + self.assertAllClose(7., tn.eval()) + + def testWeighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + weights = constant_op.constant((1., 1.5, 2., 2.5)) + tn, tn_update_op = metrics.false_positives( + labels=labels, predictions=predictions, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(14., tn_update_op.eval()) + self.assertAllClose(14., tn.eval()) + + class FalsePositivesAtThresholdsTest(test.TestCase): def setUp(self): @@ -3957,6 +4057,56 @@ class FalsePositivesAtThresholdsTest(test.TestCase): self.assertAllEqual((125.0, 42.0, 12.0), fp.eval()) +class TrueNegativesTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.true_negatives( + labels=(0, 1, 0, 1), + predictions=(0, 0, 1, 1)) + _assert_metric_variables(self, ('true_negatives/count:0',)) + + def testUnweighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + tn, tn_update_op = metrics.true_negatives( + labels=labels, predictions=predictions) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(3., tn_update_op.eval()) + self.assertAllClose(3., tn.eval()) + + def testWeighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + weights = constant_op.constant((1., 1.5, 2., 2.5)) + tn, tn_update_op = metrics.true_negatives( + labels=labels, predictions=predictions, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(4., tn_update_op.eval()) + self.assertAllClose(4., tn.eval()) + + class TrueNegativesAtThresholdsTest(test.TestCase): def setUp(self): @@ -4006,6 +4156,56 @@ class TrueNegativesAtThresholdsTest(test.TestCase): self.assertAllEqual((5.0, 15.0, 23.0), tn.eval()) +class TruePositivesTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.true_positives( + labels=(0, 1, 0, 1), + predictions=(0, 0, 1, 1)) + _assert_metric_variables(self, ('true_positives/count:0',)) + + def testUnweighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + tn, tn_update_op = metrics.true_positives( + labels=labels, predictions=predictions) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(7., tn_update_op.eval()) + self.assertAllClose(7., tn.eval()) + + def testWeighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + weights = constant_op.constant((1., 1.5, 2., 2.5)) + tn, tn_update_op = metrics.true_positives( + labels=labels, predictions=predictions, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(12., tn_update_op.eval()) + self.assertAllClose(12., tn.eval()) + + class TruePositivesAtThresholdsTest(test.TestCase): def setUp(self): diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py index 8b9c58ac3f7..40c0ade62a8 100644 --- a/tensorflow/python/kernel_tests/template_test.py +++ b/tensorflow/python/kernel_tests/template_test.py @@ -20,7 +20,9 @@ from __future__ import print_function import traceback from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -50,6 +52,13 @@ def function_with_create(trainable): "dummy", shape=[1], initializer=init_ops.zeros_initializer()) +def function_with_side_create(trainable, name="side"): + """Creates a variable as a side effect using tf.get_variable.""" + variable_scope.get_variable(name, shape=[1], trainable=trainable) + return variable_scope.get_variable( + "dummy", shape=[1], initializer=init_ops.zeros_initializer()) + + def variable_scoped_function_with_local_variable(): variable_scope.get_local_variable( "local", shape=[1], initializer=init_ops.zeros_initializer()) @@ -99,6 +108,46 @@ class TemplateTest(test.TestCase): # Parameters are tied, so the loss should have gone down when we trained it. self.assertLess(final_test_loss, initial_test_loss) + def test_end_to_end_eager(self): + """This test shows a very simple line model with test_loss in eager mode. + + The template is used to share parameters between a training and test model. + """ + with context.eager_mode(): + # y = 2x + 1 + training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7]) + test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17]) + + random_seed.set_random_seed(1234) + + def test_line(x): + m = variable_scope.get_variable( + "w", shape=[], initializer=init_ops.truncated_normal_initializer()) + b = variable_scope.get_variable( + "b", shape=[], initializer=init_ops.truncated_normal_initializer()) + return x * m + b + + line_template = template.make_template("line", test_line) + + def train_loss(): + train_prediction = line_template(training_input) + return math_ops.reduce_mean( + math_ops.square(train_prediction - training_output)) + + def test_loss(): + test_prediction = line_template(test_input) + return math_ops.reduce_mean( + math_ops.square(test_prediction - test_output)) + + optimizer = gradient_descent.GradientDescentOptimizer(0.1) + initial_test_loss = test_loss() + optimizer.minimize(train_loss) + final_test_loss = test_loss() + + # Parameters are tied, so the loss should have gone down after training. + self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy()) + + @test_util.run_in_graph_and_eager_modes() def test_skip_stack_frames(self): first = traceback.format_stack() second = traceback.format_stack() @@ -106,6 +155,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(result)) self.assertNotEqual(len(first), len(result)) + @test_util.run_in_graph_and_eager_modes() def test_template_with_name(self): tmpl1 = template.make_template("s1", variable_scoped_function) tmpl2 = template.make_template("s1", variable_scoped_function) @@ -118,15 +168,23 @@ class TemplateTest(test.TestCase): self.assertEqual("s1/dummy:0", v1.name) self.assertEqual("s1_1/dummy:0", v3.name) - def test_unique_name_raise_error(self): + def test_same_unique_name_raise_error(self): tmpl1 = template.make_template( "_", variable_scoped_function, unique_name_="s1") tmpl1() tmpl2 = template.make_template( "_", variable_scoped_function, unique_name_="s1") - with self.assertRaises(ValueError): + with self.assertRaisesRegexp( + ValueError, "Variable s1/dummy already exists, disallowed.*"): tmpl2() + def test_unique_name_raise_error_in_eager(self): + with context.eager_mode(): + with self.assertRaisesRegexp( + ValueError, "unique_name cannot be used in eager mode."): + template.make_template( + "_", variable_scoped_function, unique_name_="s1") + def test_unique_name_and_reuse(self): tmpl1 = template.make_template( "_", variable_scoped_function, unique_name_="s1") @@ -142,6 +200,7 @@ class TemplateTest(test.TestCase): self.assertEqual(v1, v3) self.assertEqual("s1/dummy:0", v1.name) + @test_util.run_in_graph_and_eager_modes() def test_template_in_scope(self): tmpl1 = template.make_template("s1", variable_scoped_function) tmpl2 = template.make_template("s1", variable_scoped_function) @@ -158,6 +217,7 @@ class TemplateTest(test.TestCase): self.assertEqual("scope/s1/dummy:0", v1.name) self.assertEqual("scope/s1_1/dummy:0", v3.name) + @test_util.run_in_graph_and_eager_modes() def test_template_with_internal_reuse(self): tmpl1 = template.make_template("s1", internally_variable_scoped_function) tmpl2 = template.make_template("s1", internally_variable_scoped_function) @@ -173,10 +233,13 @@ class TemplateTest(test.TestCase): with self.assertRaises(ValueError): tmpl1("not_test") + @test_util.run_in_graph_and_eager_modes() def test_template_without_name(self): - with self.assertRaises(ValueError): + with self.assertRaisesRegexp( + ValueError, "name cannot be None."): template.make_template(None, variable_scoped_function) + @test_util.run_in_graph_and_eager_modes() def test_make_template(self): # Test both that we can call it with positional and keywords. tmpl1 = template.make_template( @@ -199,10 +262,28 @@ class TemplateTest(test.TestCase): with self.assertRaises(ValueError): tmpl() + @test_util.run_in_graph_and_eager_modes() + def test_enforces_no_extra_trainable_variables_eager(self): + tmpl = template.make_template("s", + function_with_side_create, + trainable=True) + + tmpl(name="1") + with self.assertRaises(ValueError): + tmpl(name="2") + def test_permits_extra_non_trainable_variables(self): tmpl = template.make_template("s", function_with_create, trainable=False) self.assertEqual(tmpl(), tmpl()) + def test_permits_extra_non_trainable_variables_eager(self): + with context.eager_mode(): + tmpl = template.make_template("s", + function_with_side_create, + trainable=False) + self.assertEqual(tmpl(name="1"), tmpl(name="2")) + + @test_util.run_in_graph_and_eager_modes() def test_internal_variable_reuse(self): def nested(): @@ -241,11 +322,28 @@ class TemplateTest(test.TestCase): v1 = tmpl1() v2 = tmpl1() v3 = tmpl2() - self.assertEqual(v1, v2) + self.assertTrue(v1, v2) self.assertNotEqual(v1, v3) self.assertEqual("s1/nested_1/dummy:0", v1.name) self.assertEqual("s1_1/nested_1/dummy:0", v3.name) + def test_nested_eager_templates_raises_error(self): + + def nested_template(): + nested1 = template.make_template("nested", variable_scoped_function) + nested2 = template.make_template("nested", variable_scoped_function) + v1 = nested1() + v2 = nested2() + self.assertNotEqual(v1, v2) + return v2 + + with context.eager_mode(): + tmpl1 = template.make_template("s1", nested_template) + with self.assertRaisesRegexp( + ValueError, "Nested EagerTemaplates are not currently supported."): + tmpl1() + + @test_util.run_in_graph_and_eager_modes() def test_immediate_scope_creation(self): # Create templates in scope a then call in scope b. make_template should # capture the scope the first time it is called, and make_immediate_template @@ -270,6 +368,7 @@ class TemplateTest(test.TestCase): self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name) self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name) + @test_util.run_in_graph_and_eager_modes() def test_scope_access(self): # Ensure that we can access the scope inside the template, because the name # of that scope may be different from the name we pass to make_template, due @@ -294,6 +393,7 @@ class TemplateTest(test.TestCase): # Template is called at the top level, so there is no preceding "foo_2". self.assertEqual(tc.variable_scope.name, "blah") + @test_util.run_in_graph_and_eager_modes() def test_custom_getter(self): # Custom getter that maintains call count and forwards to true getter custom_getter_count = [0] @@ -326,6 +426,7 @@ class TemplateTest(test.TestCase): tmpl2() self.assertEqual(custom_getter_count[0], 2) + @test_util.run_in_graph_and_eager_modes() def test_fails_gracefully(self): for create_scope_now in [True, False]: def module_function_with_one_arg(inputs): @@ -336,7 +437,7 @@ class TemplateTest(test.TestCase): templatized_function = template.make_template( "f1", module_function_with_one_arg, create_scope_now_=create_scope_now) - data = array_ops.zeros(1) + data = array_ops.zeros([1]) try: # Try to connect with a kwarg which is unsupported. templatized_function(data, is_training=True) @@ -348,6 +449,7 @@ class TemplateTest(test.TestCase): templatized_function(data) self.assertTrue(templatized_function._variables_created) + @test_util.run_in_graph_and_eager_modes() def test_name_scopes_for_variable_scopes(self): # Test that name scopes are not unnecessarily uniquified (but are # still uniquified when necessary). @@ -374,12 +476,13 @@ class TemplateTest(test.TestCase): outputs_b, _ = linear1(inputs) self.assertEquals("foo", linear1.variable_scope.name) self.assertEquals("foo/w:0", w1.name) - self.assertEquals("foo/add:0", outputs_a.name, - "First application of template should get " - "same name scope as variables.") - self.assertEquals("foo_1/add:0", outputs_b.name, - "Second application of template should get " - "a freshly uniquified name scope.") + if context.in_graph_mode(): + self.assertEquals("foo/add:0", outputs_a.name, + "First application of template should get " + "same name scope as variables.") + self.assertEquals("foo_1/add:0", outputs_b.name, + "Second application of template should get " + "a freshly uniquified name scope.") linear2 = make_linear_module(output_size=2, name="foo") outputs_c, w2 = linear2(inputs) @@ -388,24 +491,30 @@ class TemplateTest(test.TestCase): "New template gets a freshly uniquified variable scope " "because 'foo' is already taken.") self.assertEquals("foo_1/w:0", w2.name) - self.assertEquals("foo_1_1/add:0", outputs_c.name, - "First application of template would get " - "same name scope as variables, but 'foo_1' is already " - "a name scope.") - self.assertEquals("foo_1_2/add:0", outputs_d.name, - "Second application of template should also get " - "a freshly uniquified name scope.") + if context.in_graph_mode(): + self.assertEquals("foo_1_1/add:0", outputs_c.name, + "First application of template would get " + "same name scope as variables, but 'foo_1' is already " + "a name scope.") + self.assertEquals("foo_1_2/add:0", outputs_d.name, + "Second application of template should also get " + "a freshly uniquified name scope.") + @test_util.run_in_graph_and_eager_modes() def test_global_variables(self): # Make sure global_variables are created. with variable_scope.variable_scope("foo"): # Create two templates with the same name, ensure scopes are made unique. ta = template.make_template("bar", variable_scoped_function, True) - tb = template.make_template("s", function_with_create, trainable=False) + if context.in_eager_mode(): + tb = template.make_template("s", function_with_side_create, + trainable=False) + else: + tb = template.make_template("s", function_with_create, trainable=False) # Initially there are not variables created. - self.assertEqual([], ta.global_variables) - self.assertEqual([], tb.global_variables) + self.assertEqual([], list(ta.global_variables)) + self.assertEqual([], list(tb.global_variables)) # After calling there are variables created. ta() tb() @@ -413,6 +522,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(ta.global_variables)) self.assertEqual(2, len(tb.global_variables)) + @test_util.run_in_graph_and_eager_modes() def test_trainable_variables(self): # Make sure trainable_variables are created. with variable_scope.variable_scope("foo2"): @@ -421,8 +531,8 @@ class TemplateTest(test.TestCase): tb = template.make_template("bar", variable_scoped_function, True) # Initially there are not variables created. - self.assertEqual([], ta.trainable_variables) - self.assertEqual([], tb.trainable_variables) + self.assertEqual([], list(ta.trainable_variables)) + self.assertEqual([], list(tb.trainable_variables)) # After calling there are variables created. ta() tb() @@ -430,6 +540,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(ta.trainable_variables)) self.assertEqual(1, len(tb.trainable_variables)) + # TODO(apassos) handle local variables in Eager def test_local_variables(self): # Make sure trainable_variables are created. with variable_scope.variable_scope("foo3"): @@ -439,8 +550,8 @@ class TemplateTest(test.TestCase): variable_scoped_function_with_local_variable) # Initially there are not variables created. - self.assertEqual([], ta.local_variables) - self.assertEqual([], tb.local_variables) + self.assertEqual([], list(ta.local_variables)) + self.assertEqual([], list(tb.local_variables)) # After calling there are variables created. ta() tb() diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index 0f3b11e7f9f..835fdbe2aa5 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -43,6 +43,10 @@ import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test +# TODO(ebrevdo): Delete this line after Dec. 4, 2017. +tensor_array_ops._ENABLE_IDENTICAL_ELEMENT_SHAPES = True + + def _make_converter(tf_dtype): def _converter(x): if tf_dtype == dtypes.string: @@ -186,6 +190,22 @@ class TensorArrayTest(test.TestCase): def testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros(self): self._testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros() + def _testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3) + self.assertAllEqual( + [[0.0, 0.0]], self.evaluate(ta.write(1, [[4.0, 5.0]]).read(0))) + self.assertAllEqual([[[0.0, 0.0]], [[4.0, 5.0]], [[0.0, 0.0]]], + self.evaluate(ta.write(1, [[4.0, 5.0]]).stack())) + self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]], + self.evaluate(ta.write(1, [[4.0, 5.0]]).concat())) + + @test_util.run_in_graph_and_eager_modes() + def testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self): + self._testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros() + def _testTensorArrayUnpackRead(self, tf_dtype): with self.test_session(use_gpu=True): convert = _make_converter(tf_dtype) @@ -739,7 +759,8 @@ class TensorArrayTest(test.TestCase): def testTensorArrayGradientSplitConcat(self): with self.test_session(use_gpu=True) as session: ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=2) + dtype=dtypes.float32, tensor_array_name="foo", size=2, + infer_shape=False) value = constant_op.constant( [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]]) diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index c5bf4c6080a..6be2bc3e769 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -30,6 +30,7 @@ from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import utils as layers_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables @@ -250,7 +251,7 @@ class Layer(object): if inputs is not None: # We compute an ID that uniquely identifies the list of tensors. # This ID is order-sensitive. - inputs_hash = _object_list_uid(inputs) + inputs_hash = layers_util.object_list_uid(inputs) else: inputs_hash = None if inputs_hash not in self._per_input_updates: @@ -279,7 +280,7 @@ class Layer(object): if not inputs: inputs = None if inputs is not None: - inputs_hash = _object_list_uid(inputs) + inputs_hash = layers_util.object_list_uid(inputs) else: inputs_hash = None return self._per_input_updates.get(inputs_hash, []) @@ -326,7 +327,7 @@ class Layer(object): if inputs is not None: # We compute an ID that uniquely identifies the list of tensors. # This ID is order-sensitive. - inputs_hash = _object_list_uid(inputs) + inputs_hash = layers_util.object_list_uid(inputs) else: inputs_hash = None if inputs_hash not in self._per_input_losses: @@ -357,7 +358,7 @@ class Layer(object): if not inputs: inputs = None if inputs is not None: - inputs_hash = _object_list_uid(inputs) + inputs_hash = layers_util.object_list_uid(inputs) else: inputs_hash = None return self._per_input_losses.get(inputs_hash, []) @@ -378,6 +379,10 @@ class Layer(object): """ return inputs + def _name_scope_name(self, current_variable_scope): + """Determines op naming for the Layer.""" + return current_variable_scope.original_name_scope + def _compute_output_shape(self, input_shape): """Computes the output shape of the layer given the input shape. @@ -402,10 +407,11 @@ class Layer(object): return input_shape def _make_unique_name(self, name_uid_map=None, avoid_names=None, - namespace=''): + namespace='', zero_based=False): base_name = _to_snake_case(self.__class__.__name__) name = _unique_layer_name(base_name, name_uid_map=name_uid_map, - avoid_names=avoid_names, namespace=namespace) + avoid_names=avoid_names, namespace=namespace, + zero_based=zero_based) return (name, base_name) def _set_scope(self, scope=None): @@ -472,7 +478,7 @@ class Layer(object): self._set_scope(None) with vs.variable_scope( self._scope, reuse=(self.built or self._reuse)) as scope: - with ops.name_scope(scope.original_name_scope): + with ops.name_scope(self._name_scope_name(scope)): variable = vs.get_variable(name, shape=shape, initializer=initializer, @@ -575,7 +581,7 @@ class Layer(object): scope_context_manager = vs.variable_scope( self._scope, reuse=self._reuse) with scope_context_manager as scope: - with ops.name_scope(scope.original_name_scope): + with ops.name_scope(self._name_scope_name(scope)): if not self.built: if not in_graph_mode: # Activity regularization is currently unsupported in Eager mode. @@ -1266,9 +1272,9 @@ class Node(object): # Following 2 properties: input and output shapes. # List of shape tuples, shapes of input_tensors. - self.input_shapes = [_static_shape(x) for x in input_tensors] + self.input_shapes = [layers_util.static_shape(x) for x in input_tensors] # List of shape tuples, shapes of output_tensors. - self.output_shapes = [_static_shape(x) for x in output_tensors] + self.output_shapes = [layers_util.static_shape(x) for x in output_tensors] # Optional keyword arguments to layer's `call`. self.arguments = arguments @@ -1326,926 +1332,6 @@ class _DeferredTensor(object): self.dtype.name) -class InputLayer(Layer): - """Layer to be used as an entry point into a Network (a graph of layers). - - It can either wrap an existing tensor (pass an `input_tensor` argument) - or create its a placeholder tensor (pass arguments `input_shape` - as well as `dtype`). - - It is generally recommend to use the functional layer API via `Input`, - (which creates an `InputLayer`) without directly using `InputLayer`. - - Arguments: - input_shape: Shape tuple (not including the batch axis), or `TensorShape` - instance (not including the batch axis). - batch_size: Optional input batch size (integer or None). - dtype: Datatype of the input. - input_tensor: Optional tensor to use as layer input - instead of creating a placeholder. - sparse: Boolean, whether the placeholder created - is meant to be sparse. - name: Name of the layer (string). - - Raises: - RuntimeError: If created in Eager mode. - """ - - def __init__(self, - input_shape=None, - batch_size=None, - dtype=dtypes.float32, - input_tensor=None, - sparse=False, - name=None): - super(InputLayer, self).__init__(dtype=dtype, name=name) - self.built = True - self.sparse = sparse - self.batch_size = batch_size - - if isinstance(input_shape, tensor_shape.TensorShape): - input_shape = tuple(input_shape.as_list()) - - if input_tensor is None: - if input_shape is not None: - batch_input_shape = (batch_size,) + tuple(input_shape) - else: - batch_input_shape = None - - if context.in_eager_mode(): - # In eager mode, create a temporary placeholder to call the layer on. - input_tensor = _DeferredTensor( - shape=batch_input_shape, - dtype=dtype, - name=self.name) - else: - # In graph mode, create a graph placeholder to call the layer on. - if sparse: - input_tensor = array_ops.sparse_placeholder( - shape=batch_input_shape, - dtype=dtype, - name=self.name) - else: - input_tensor = array_ops.placeholder( - shape=batch_input_shape, - dtype=dtype, - name=self.name) - - # For compatibility with Keras API. - self.is_placeholder = True - self._batch_input_shape = batch_input_shape - else: - # For compatibility with Keras API. - self.is_placeholder = False - self._batch_input_shape = tuple(input_tensor.get_shape().as_list()) - - # Create an input node to add to self.outbound_node - # and set output_tensors' _keras_history. - input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access - Node( - self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=[input_tensor], - output_tensors=[input_tensor]) - - -def Input( # pylint: disable=invalid-name - shape=None, - batch_size=None, - name=None, - dtype=dtypes.float32, - sparse=False, - tensor=None): - """`Input()` is used to instantiate an input tensor for use with a `Network`. - - For instance, if a, b and c are tensors created via `Input`, - it becomes possible to do: - - `network = Network(inputs=[a, b], outputs=c)` - - Example: - - ```python - # This is a logistic regression - x = tf.layers.Input(shape=(32,)) - y = tf.layers.Dense(16, activation='softmax')(x) - network = tf.layers.Network(x, y) - ``` - - Arguments: - shape: A shape tuple (integer), not including the batch size. - For instance, `shape=(32,)` indicates that the expected input - will be batches of 32-dimensional vectors. - batch_size: Optional input batch size (integer or None). - name: An optional name string for the layer. - Should be unique in a model (do not reuse the same name twice). - It will be autogenerated if it isn't provided. - dtype: The data type expected by the input, as a string - (`float32`, `float64`, `int32`...) - sparse: A boolean specifying whether the placeholder - to be created is sparse. - tensor: Optional existing tensor to wrap into the `Input` layer. - If set, the layer will not create a placeholder tensor. - - Returns: - A tensor: either a new placeholder (with history metadata) or - `tensor` (if passed), with added history metadata. - - Raises: - RuntimeError: If called in Eager mode. - """ - input_layer = InputLayer( - input_shape=shape, - batch_size=batch_size, - name=name, - dtype=dtype, - sparse=sparse, - input_tensor=tensor) - # Return tensor including `_keras_history` metadata. - # Note that in this case train_output and test_output are the same pointer. - outputs = input_layer._inbound_nodes[0].output_tensors # pylint: disable=protected-access - if len(outputs) == 1: - return outputs[0] - else: - return outputs - - -class Network(Layer): - """A Network is a directed acyclic graph of layers. - - It is the topological form of a "model". - A Model is simply a Network with added training/evaluation routines. - - A Network instance implements the full Layer API. In particular, a network - can be called on new inputs. - - Example: - - ```python - # This is a logistic regression - x = tf.layers.Input(shape=(32,)) - y = tf.layers.Dense(16, activation='softmax')(x) - network = tf.layers.Network(x, y) - - # It is then possible to call the network on compatible inputs: - z = tf.layers.Input(shape=(32,)) - w = network(z) - - # It is possible to retrieve the same properties as a layer: - weights = network.trainable_weights - ``` - - Arguments: - inputs: Input tensor or list of input tensors. - Must come from `tf.layers.Input`. - output: Output tensor or list of output tensors. Must come from - tf.layers Layers or Keras layers. - name: Optional name of the model (string). - - Attributes: - Network has the same attributes as Layer. On top of it, it also has: - - layers: a list of the children layers of the network, - a list of layer instances, ordered from "earlier in the graph" - to "later in the graph". - - Methods: - Network has the same methods as Layer. On top of it, it also has: - - get_layer: retrieves a child layer by name or index in the graph. - - Raises: - RuntimeError: If created in Eager mode. - """ - - def __init__(self, inputs, outputs, name=None): # pylint: disable=super-init-not-called - if context.in_eager_mode(): - # TODO(fchollet): check that all inputs and outputs are DeferredTensors. - pass - - self._init_set_name(name) - self._activity_regularizer = None - with vs.variable_scope( - None, default_name=self._base_name) as captured_scope: - self._scope = captured_scope - call_fn_args = estimator_util.fn_args(self.call) - self._compute_previous_mask = ('mask' in call_fn_args or - hasattr(self, 'compute_mask')) - self._call_has_scope_arg = 'scope' in call_fn_args - - # This acts just like the `trainable` attribute of any layer instance. - # It does not affect users of the underlying layers, only users of the - # Network instance. - self.trainable = True - # A Network does not create weights of its own, thus it is already built. - self.built = True - # A Network does not create weights of its own, thus has no dtype. - self._dtype = None - # The following are implemented as property functions: - # self.trainable_weights - # self.non_trainable_weights - # self.input_spec - - # Private attributes to implement compatibility with Layer. - self._per_input_losses = {} - self._per_input_updates = {} - self._updates = [] - self._losses = [] - self._scope = None - self._reuse = None - self._graph = ops.get_default_graph() - - # Network-specific properties. - if isinstance(inputs, (list, tuple)): - self.inputs = list(inputs) # Tensor or list of tensors. - else: - self.inputs = [inputs] - if isinstance(outputs, (list, tuple)): - self.outputs = list(outputs) - else: - self.outputs = [outputs] - # All layers in order of horizontal graph traversal. - # Entries are unique. Includes input and output layers. - self.layers = [] - - # Check for redundancy in inputs. - if len(set(self.inputs)) != len(self.inputs): - raise ValueError('The list of inputs passed to the model ' - 'is redundant. ' - 'All inputs should only appear once.' - ' Found: ' + str(self.inputs)) - - # # List of initial layers (1 to 1 mapping with self.inputs, - # # hence the same layer might appear twice) - # self._input_layers = [] - # self._input_layers_node_indices = [] - # self._input_layers_tensor_indices = [] - # # list of layers (1 to 1 mapping with self.inputs, - # # hence the same layer might appear twice) - # self._output_layers = [] - # self._output_layers_node_indices = [] - # self._output_layers_tensor_indices = [] - - self._input_layers = [] - self._output_layers = [] - self._input_coordinates = [] - self._output_coordinates = [] - - # This is for performance optimization - # when calling the Network on new inputs. - # every time the Network is called on a set on input tensors, - # we compute the output tensors, - # output masks and output shapes in one pass, - # then cache them here. When any of these outputs is queried later, - # we retrieve it from there instead of recomputing it. - self._output_mask_cache = {} - self._output_tensor_cache = {} - self._output_shape_cache = {} - - # User-provided arguments validation. - for x in self.inputs: - # Check that x has appropriate `_keras_history` metadata. - if not hasattr(x, '_keras_history'): - cls_name = self.__class__.__name__ - raise ValueError('Input tensors to a ' + cls_name + ' ' + - 'must come from `tf.layers.Input`. ' - 'Received: ' + str(x) + - ' (missing previous layer metadata).') - # Check that x is an input tensor. - # pylint: disable=protected-access - layer, node_index, tensor_index = x._keras_history - if len(layer._inbound_nodes) > 1 or ( - layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers): - cls_name = self.__class__.__name__ - logging.warning(cls_name + ' inputs must come from ' - '`tf.layers.Input` (thus holding past layer metadata), ' - 'they cannot be the output of ' - 'a previous non-Input layer. ' - 'Here, a tensor specified as ' - 'input to "' + self.name + '" was not an Input tensor, ' - 'it was generated by layer ' + layer.name + '.\n' - 'Note that input tensors are ' - 'instantiated via `tensor = tf.layers.Input(shape)`.\n' - 'The tensor that caused the issue was: ' + str(x.name)) - # pylint: enable=protected-access - for x in self.outputs: - if not hasattr(x, '_keras_history'): - cls_name = self.__class__.__name__ - raise ValueError('Output tensors to a ' + cls_name + ' must be ' - 'the output of a TensorFlow `Layer` ' - '(thus holding past layer metadata). Found: ' + str(x)) - - # Build self._output_layers: - for x in self.outputs: - layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access - self._output_layers.append(layer) - self._output_coordinates.append((layer, node_index, tensor_index)) - - # Build self._input_layers: - for x in self.inputs: - layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access - # It's supposed to be an input layer, so only one node - # and one tensor output. - assert node_index == 0 - assert tensor_index == 0 - self._input_layers.append(layer) - self._input_coordinates.append((layer, node_index, tensor_index)) - - # Network_nodes: set of nodes included in the graph - # (not all nodes included in the layers - # are relevant to the current graph). - network_nodes = set() # ids of all nodes relevant to the Network - nodes_depths = {} # dict {node: depth value} - layers_depths = {} # dict {layer: depth value} - layer_indices = {} # dict {layer: index in traversal} - nodes_in_decreasing_depth = [] - - def build_map_of_graph(tensor, - finished_nodes, - nodes_in_progress, - layer, - node_index, - tensor_index): - """Builds a map of the graph of layers. - - This recursively updates the map `layer_indices`, - the list `nodes_in_decreasing_depth` and the set `network_nodes`. - - Arguments: - tensor: Some tensor in a graph. - finished_nodes: Set of nodes whose subgraphs have been traversed - completely. Useful to prevent duplicated work. - nodes_in_progress: Set of nodes that are currently active on the - recursion stack. Useful to detect cycles. - layer: Layer from which `tensor` comes from. If not provided, - will be obtained from `tensor._keras_history`. - node_index: Node index from which `tensor` comes from. - tensor_index: Tensor_index from which `tensor` comes from. - - Raises: - ValueError: if a cycle is detected. - """ - node = layer._inbound_nodes[node_index] # pylint: disable=protected-access - - # Prevent cycles. - if node in nodes_in_progress: - raise ValueError('The tensor ' + str(tensor) + ' at layer "' + - layer.name + '" is part of a cycle.') - - # Don't repeat work for shared subgraphs - if node in finished_nodes: - return - - node_key = _make_node_key(layer.name, node_index) - # Update network_nodes. - network_nodes.add(node_key) - - # Store the traversal order for layer sorting. - if layer not in layer_indices: - layer_indices[layer] = len(layer_indices) - - nodes_in_progress.add(node) - - # Propagate to all previous tensors connected to this node. - for i in range(len(node.inbound_layers)): - x = node.input_tensors[i] - layer = node.inbound_layers[i] - node_index = node.node_indices[i] - tensor_index = node.tensor_indices[i] - build_map_of_graph(x, finished_nodes, nodes_in_progress, layer, - node_index, tensor_index) - - finished_nodes.add(node) - nodes_in_progress.remove(node) - nodes_in_decreasing_depth.append(node) - - finished_nodes = set() - nodes_in_progress = set() - for x in self.outputs: - layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access - build_map_of_graph(x, finished_nodes, nodes_in_progress, - layer=layer, - node_index=node_index, - tensor_index=tensor_index) - - for node in reversed(nodes_in_decreasing_depth): - # If the depth is not set, the node has no outbound nodes (depth 0). - depth = nodes_depths.setdefault(node, 0) - - # Update the depth of the corresponding layer - previous_depth = layers_depths.get(node.outbound_layer, 0) - # If we've seen this layer before at a higher depth, - # we should use that depth instead of the node depth. - # This is necessary for shared layers that have inputs at different - # depth levels in the graph. - depth = max(depth, previous_depth) - layers_depths[node.outbound_layer] = depth - nodes_depths[node] = depth - - # Update the depth of inbound nodes. - # The "depth" of a node is the max of the depths - # of all layers it is connected to. - for i in range(len(node.inbound_layers)): - inbound_layer = node.inbound_layers[i] - node_index = node.node_indices[i] - inbound_node = inbound_layer._inbound_nodes[node_index] # pylint: disable=protected-access - previous_depth = nodes_depths.get(inbound_node, 0) - nodes_depths[inbound_node] = max(depth + 1, previous_depth) - - # Build a dict {depth: list of nodes with this depth} - nodes_by_depth = {} - for node, depth in nodes_depths.items(): - if depth not in nodes_by_depth: - nodes_by_depth[depth] = [] - nodes_by_depth[depth].append(node) - - # Build a dict {depth: list of layers with this depth} - layers_by_depth = {} - for layer, depth in layers_depths.items(): - if depth not in layers_by_depth: - layers_by_depth[depth] = [] - layers_by_depth[depth].append(layer) - - # Get sorted list of layer depths. - depth_keys = list(layers_by_depth.keys()) - depth_keys.sort(reverse=True) - - # Set self.layers and self._layers_by_depth. - layers = [] - for depth in depth_keys: - layers_for_depth = layers_by_depth[depth] - # Network.layers needs to have a deterministic order: - # here we order them by traversal order. - layers_for_depth.sort(key=lambda x: layer_indices[x]) - layers.extend(layers_for_depth) - self.layers = layers - self._layers_by_depth = layers_by_depth - - # Get sorted list of node depths. - depth_keys = list(nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - - # Check that all tensors required are computable. - # computable_tensors: all tensors in the graph - # that can be computed from the inputs provided. - computable_tensors = [] - for x in self.inputs: - computable_tensors.append(x) - - layers_with_complete_input = [] # To provide a better error msg. - for depth in depth_keys: - for node in nodes_by_depth[depth]: - layer = node.outbound_layer - if layer: - for x in node.input_tensors: - if x not in computable_tensors: - raise ValueError('Graph disconnected: ' - 'cannot obtain value for tensor ' + str(x) + - ' at layer "' + layer.name + '". ' - 'The following previous layers ' - 'were accessed without issue: ' + - str(layers_with_complete_input)) - for x in node.output_tensors: - computable_tensors.append(x) - layers_with_complete_input.append(layer.name) - - # Keep track of the network's nodes. - self._network_nodes = network_nodes - self._nodes_by_depth = nodes_by_depth - - # Ensure name unicity, which will be crucial for serialization - # (since serialized nodes refer to layers by their name). - all_names = [layer.name for layer in self.layers] - for name in all_names: - if all_names.count(name) != 1: - raise ValueError('The name "' + name + '" is used ' + - str(all_names.count(name)) + ' times in the model. ' - 'All layer names should be unique.') - - # Layer parameters. - # The new network starts with a single inbound node - # for its inputs, and no outbound nodes. - self._outbound_nodes = [] # Will be appended to by future calls to __call__ - self._inbound_nodes = [ - ] # Will be appended to below, and by future calls to __call__ - # Create the node linking internal inputs to internal outputs. - Node( - outbound_layer=self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=self.inputs, - output_tensors=self.outputs) - - def get_layer(self, name=None, index=None): - """Retrieves a layer based on either its name (unique) or index. - - Indices are based on order of horizontal graph traversal (bottom-up). - - Arguments: - name: String, name of layer. - index: Integer, index of layer. - - Returns: - A layer instance. - - Raises: - ValueError: In case of invalid layer name or index. - """ - # TODO(fchollet): We could build a dictionary based on layer names - # since they are constant, but we have not done that yet. - if index is not None: - if len(self.layers) <= index: - raise ValueError('Was asked to retrieve layer at index ' + str(index) + - ' but model only has ' + str(len(self.layers)) + - ' layers.') - else: - return self.layers[index] - else: - if not name: - raise ValueError('Provide either a layer name or layer index.') - for layer in self.layers: - if layer.name == name: - return layer - raise ValueError('No such layer: ' + name) - - @property - def updates(self): - """Retrieve the network's updates. - - Will only include updates that are either - unconditional, or conditional on inputs to this model - (e.g. will not include updates that depend on tensors - that aren't inputs to this model). - - Returns: - A list of update ops. - """ - updates = [] - for layer in self.layers: - if hasattr(layer, 'updates'): - # Collect updates that are dependent on inputs - # that are part of the model. - for node_index, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access - node_key = _make_node_key(layer.name, node_index) - if node_key in self._network_nodes: - # The model owns this layer node. - inputs = node.input_tensors - updates += layer.get_updates_for(inputs) - # Collect unconditional updates. - updates += layer.get_updates_for(None) - return updates - - @property - def losses(self): - """Retrieve the network's losses. - - Will only include losses that are either - unconditional, or conditional on inputs to this model - (e.g. will not include losses that depend on tensors - that aren't inputs to this model). - - Returns: - A list of loss tensors. - """ - losses = [] - # Retrieve losses for all internal layers. - for layer in self.layers: - if hasattr(layer, 'losses'): - # Collect losses that are dependent on inputs - # that are part of the model. - for node_index, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access - node_key = _make_node_key(layer.name, node_index) - if node_key in self._network_nodes: - # The model owns this layer node. - inputs = node.input_tensors - losses += layer.get_losses_for(inputs) - # Collect unconditional losses. - losses += layer.get_losses_for(None) - # Add any potential unconditional model-level loss. - losses += self.get_losses_for(None) - return losses - - @property - def trainable_weights(self): - if not self.trainable: - return [] - weights = [] - for layer in self.layers: - weights += layer.trainable_weights - return weights - - @property - def non_trainable_weights(self): - weights = [] - for layer in self.layers: - weights += layer.non_trainable_weights - if not self.trainable: - trainable_weights = [] - for layer in self.layers: - trainable_weights += layer.trainable_weights - return trainable_weights + weights - return weights - - @property - def input_spec(self): - """Gets the network's input specs. - - Returns: - A list of `InputSpec` instances (one per input to the model) - or a single instance if the model has only one input. - """ - specs = [] - for layer in self._input_layers: - if layer.input_spec is None: - specs.append(None) - else: - if not isinstance(layer.input_spec, list): - raise TypeError('Layer ' + layer.name + - ' has an input_spec attribute that ' - 'is not a list. We expect a list. ' - 'Found input_spec = ' + str(layer.input_spec)) - specs += layer.input_spec - if len(specs) == 1: - return specs[0] - return specs - - def call(self, inputs, mask=None): - """Call the model on new inputs. - - In this case `call` just reapplies - all ops in the graph to the new inputs - (e.g. build a new computational graph from the provided inputs). - - Arguments: - inputs: A tensor or list of tensors. - mask: A mask or list of masks. A mask can be - either a tensor or None (no mask). - - Returns: - A tensor if there is a single output, or - a list of tensors if there are more than one outputs. - """ - inputs = nest.flatten(inputs) - if mask is None: - masks = [None for _ in range(len(inputs))] - else: - masks = nest.flatten(mask) - - if context.in_graph_mode(): - # Try to retrieve cached outputs if the layer has already been called - # on these exact inputs. - cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks) - if cache_key in self._output_tensor_cache: - # Cache hit. - return self._output_tensor_cache[cache_key] - # Actually apply the network graph to the new inputs. - outputs, _ = self._run_internal_graph(inputs, masks) - return outputs - - def _compute_output_shape(self, input_shape): - if isinstance(input_shape, list): - input_shapes = [] - for shape in input_shape: - if shape is not None: - input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list())) - else: - input_shapes.append(None) - else: - if input_shape is not None: - input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())] - else: - input_shapes = [None] - - if len(input_shapes) != len(self._input_layers): - raise ValueError('Invalid input_shape argument ' + str(input_shape) + - ': model has ' + str(len(self._input_layers)) + - ' tensor inputs.') - - cache_key = _object_list_uid(input_shapes) - if cache_key not in self._output_shape_cache: - # Cache miss. We have to run the network graph manually (recursive calls - # to `_compute_output_shape`). - layers_to_output_shapes = {} - for i in range(len(input_shapes)): - layer = self._input_layers[i] - input_shape = input_shapes[i] - # It's an input layer: then `_compute_output_shape` is identity, - # and there is only one node and one tensor output. - shape_key = layer.name + '_0_0' - layers_to_output_shapes[shape_key] = input_shape - - depth_keys = list(self._nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - # Iterate over nodes, by depth level. - if len(depth_keys) > 1: - for depth in depth_keys: - nodes = self._nodes_by_depth[depth] - for node in nodes: - # This is always a single layer, never a list. - layer = node.outbound_layer - if layer in self._input_layers: - # We've already covered the input layers - # a few lines above. - continue - # Potentially redundant list, - # same size as node.input_tensors. - input_shapes = [] - for j in range(len(node.inbound_layers)): - inbound_layer = node.inbound_layers[j] - node_index = node.node_indices[j] - tensor_index = node.tensor_indices[j] - shape_key = inbound_layer.name + '_%s_%s' % (node_index, - tensor_index) - input_shape = layers_to_output_shapes[shape_key] - input_shapes.append(input_shape) - - if len(input_shapes) == 1: - output_shape = layer._compute_output_shape(input_shapes[0]) # pylint: disable=protected-access - else: - output_shape = layer._compute_output_shape(input_shapes) # pylint: disable=protected-access - if isinstance(output_shape, list): - output_shapes = [ - tuple(tensor_shape.TensorShape(shape).as_list()) - for shape in output_shape - ] - else: - output_shapes = [ - tuple(tensor_shape.TensorShape(output_shape).as_list()) - ] - - node_index = layer._inbound_nodes.index(node) # pylint: disable=protected-access - for j in range(len(output_shapes)): - shape_key = layer.name + '_%s_%s' % (node_index, j) - layers_to_output_shapes[shape_key] = output_shapes[j] - - # Read final output shapes from layers_to_output_shapes. - output_shapes = [] - for i in range(len(self._output_layers)): - layer, node_index, tensor_index = self._output_coordinates[i] - shape_key = layer.name + '_%s_%s' % (node_index, tensor_index) - output_shapes.append(layers_to_output_shapes[shape_key]) - - # Store in cache. - self._output_shape_cache[cache_key] = output_shapes - else: - # Cache hit. - output_shapes = self._output_shape_cache[cache_key] - - if isinstance(output_shapes, list): - if len(output_shapes) == 1: - return tensor_shape.TensorShape(output_shapes[0]) - else: - return [tensor_shape.TensorShape(shape) for shape in output_shapes] - else: - return tensor_shape.TensorShape(output_shapes) - - def _run_internal_graph(self, inputs, masks=None): - """Computes output tensors for new inputs. - - # Note: - - Expects `inputs` to be a list (potentially with 1 element). - - Can be run on non-Keras tensors. - - Arguments: - inputs: List of tensors - masks: List of masks (tensors or None). - - Returns: - Three lists: output_tensors, output_masks, output_shapes - """ - # Note: masking support is relevant mainly for Keras. - # It cannot be factored out without having the fully reimplement the - # network calling logic on the Keras side. We choose to incorporate it - # in Network because 1) it may be useful to fully support in tf.layers in - # the future and 2) Keras is a major user of Network. - # If you don't use masking, it does not interfere with regular behavior - # at all and you can ignore it. - if masks is None: - masks = [None for _ in range(len(inputs))] - - # Dictionary mapping reference tensors to tuples - # (computed tensor, compute mask) - # we assume a 1:1 mapping from tensor to mask - # TODO(fchollet): raise exception when a `.compute_mask()` call - # does not return a list the same size as `call` - tensor_map = {} - for x, y, mask in zip(self.inputs, inputs, masks): - tensor_map[str(id(x))] = (y, mask) - - depth_keys = list(self._nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - for depth in depth_keys: - nodes = self._nodes_by_depth[depth] - for node in nodes: - # This is always a single layer, never a list. - layer = node.outbound_layer - - reference_input_tensors = node.input_tensors - reference_output_tensors = node.output_tensors - - # If all previous input tensors are available in tensor_map, - # then call node.inbound_layer on them. - computed_data = [] # List of tuples (input, mask). - for x in reference_input_tensors: - if str(id(x)) in tensor_map: - computed_data.append(tensor_map[str(id(x))]) - - if len(computed_data) == len(reference_input_tensors): - # Call layer (reapplying ops to new inputs). - with ops.name_scope(layer.name): - if node.arguments: - kwargs = node.arguments - else: - kwargs = {} - if len(computed_data) == 1: - computed_tensor, computed_mask = computed_data[0] - # Ensure mask propagation if applicable. - if 'mask' in estimator_util.fn_args(layer.call): - if 'mask' not in kwargs: - kwargs['mask'] = computed_mask - - output_tensors = nest.flatten( - layer.call(computed_tensor, **kwargs)) - if hasattr(layer, 'compute_mask'): - output_masks = nest.flatten( - layer.compute_mask(computed_tensor, computed_mask)) - else: - output_masks = [None for _ in range(len(output_tensors))] - computed_tensors = [computed_tensor] - computed_masks = [computed_mask] - else: - computed_tensors = [x[0] for x in computed_data] - computed_masks = [x[1] for x in computed_data] - if 'mask' in estimator_util.fn_args(layer.call): - if 'mask' not in kwargs: - kwargs['mask'] = computed_masks - output_tensors = nest.flatten( - layer.call(computed_tensors, **kwargs)) - if hasattr(layer, 'compute_mask'): - output_masks = nest.flatten( - layer.compute_mask(computed_tensors, computed_masks)) - else: - output_masks = [None for _ in range(len(output_tensors))] - - # Apply activity regularizer if any: - if layer.activity_regularizer is not None: - regularization_losses = [ - layer.activity_regularizer(x) for x in computed_tensors - ] - layer.add_loss(regularization_losses, computed_tensors) - - if context.in_graph_mode(): - # Update model updates and losses: - # Keep track of updates that depend on the inputs - # (e.g. BN updates). - self.add_update(layer.get_updates_for(computed_tensors), inputs) - # Keep track of unconditional updates (e.g. a counter). - self.add_update(layer.get_updates_for(None), None) - # Keep track of losses that depend on the inputs - # (e.g. activity regularizers). - self.add_loss(layer.get_losses_for(computed_tensors), inputs) - # Keep track of unconditional losses - # (e.g. weight regularizers). - self.add_loss(layer.get_losses_for(None), None) - - # Update tensor_map. - for x, y, mask in zip(reference_output_tensors, output_tensors, - output_masks): - tensor_map[str(id(x))] = (y, mask) - - output_tensors = [] - output_masks = [] - output_shapes = [] - for x in self.outputs: - assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x) - tensor, mask = tensor_map[str(id(x))] - output_shapes.append(_static_shape(x)) - output_tensors.append(tensor) - output_masks.append(mask) - - if len(output_tensors) == 1: - output_tensors = output_tensors[0] - if output_shapes is not None: - output_shapes = output_shapes[0] - if output_masks is not None: - output_masks = output_masks[0] - - if context.in_graph_mode(): - # Update cache; - # keys are based on ids on input tensors and inputs masks. - cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks) - self._output_tensor_cache[cache_key] = output_tensors - if output_masks is not None: - self._output_mask_cache[cache_key] = output_masks - if output_shapes is not None: - input_shapes = [_static_shape(x) for x in inputs] - cache_key = _object_list_uid(input_shapes) - self._output_shape_cache[cache_key] = output_shapes - - return output_tensors, output_masks - - def _is_tensor_or_tensor_list(v): v = nest.flatten(v) if v and isinstance(v[0], ops.Tensor): @@ -2296,24 +1382,6 @@ def _add_elements_to_collection(elements, collection_list): collection.append(element) -def _object_list_uid(object_list): - object_list = nest.flatten(object_list) - return ', '.join([str(abs(id(x))) for x in object_list]) - - -def _make_node_key(layer_name, node_index): - return layer_name + '_ib-' + str(node_index) - - -def _static_shape(x): - if x is None: - return None - try: - return tuple(x.get_shape().as_list()) - except ValueError: - return None - - def _is_all_none(iterable_or_element): if not isinstance(iterable_or_element, (list, tuple)): iterable = [iterable_or_element] @@ -2371,7 +1439,8 @@ def _get_default_graph_uid_map(): return name_uid_map -def _unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace=''): +def _unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace='', + zero_based=False): """Makes a layer name (or arbitrary string) unique within a TensorFlow graph. Arguments: @@ -2383,6 +1452,8 @@ def _unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace=''): namespace: Gets a name which is unique within the (graph, namespace). Layers which are not Networks use a blank namespace and so get graph-global names. + zero_based: If True, name sequences start with no suffix (e.g. "dense", + "dense_1"). If False, naming is one-based ("dense_1", "dense_2"). Returns: Unique string name. @@ -2401,6 +1472,14 @@ def _unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace=''): proposed_name = None while proposed_name is None or proposed_name in avoid_names: name_key = (namespace, name) - name_uid_map[name_key] += 1 - proposed_name = name + '_' + str(name_uid_map[name_key]) + if zero_based: + number = name_uid_map[name_key] + if number: + proposed_name = name + '_' + str(number) + else: + proposed_name = name + name_uid_map[name_key] += 1 + else: + name_uid_map[name_key] += 1 + proposed_name = name + '_' + str(name_uid_map[name_key]) return proposed_name diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index 509ad5a7afb..1eea20deefe 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -20,8 +20,6 @@ from __future__ import print_function import copy -import numpy as np - from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -33,7 +31,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -431,115 +428,6 @@ class BaseLayerTest(test.TestCase): layer.apply(array_ops.placeholder('int32')) layer.apply(array_ops.placeholder('int32', shape=(2, 3))) - def test_get_updates_for(self): - a = base_layers.Input(shape=(2,)) - dense_layer = core_layers.Dense(1) - dense_layer.add_update(0, inputs=a) - dense_layer.add_update(1, inputs=None) - - self.assertEqual(dense_layer.get_updates_for(a), [0]) - self.assertEqual(dense_layer.get_updates_for(None), [1]) - - def test_get_losses_for(self): - a = base_layers.Input(shape=(2,)) - dense_layer = core_layers.Dense(1) - dense_layer.add_loss(0, inputs=a) - dense_layer.add_loss(1, inputs=None) - - self.assertEqual(dense_layer.get_losses_for(a), [0]) - self.assertEqual(dense_layer.get_losses_for(None), [1]) - - def testTopologicalAttributes(self): - # test layer attributes / methods related to cross-layer connectivity. - a = base_layers.Input(shape=(32,), name='input_a') - b = base_layers.Input(shape=(32,), name='input_b') - - # test input, output, input_shape, output_shape - test_layer = core_layers.Dense(16, name='test_layer') - a_test = test_layer(a) - self.assertEqual(test_layer.input, a) - self.assertEqual(test_layer.output, a_test) - self.assertEqual(test_layer.input_shape, (None, 32)) - self.assertEqual(test_layer.output_shape, (None, 16)) - - # test `get_*_at` methods - dense = core_layers.Dense(16, name='dense_1') - a_2 = dense(a) - b_2 = dense(b) - - self.assertEqual(dense.get_input_at(0), a) - self.assertEqual(dense.get_input_at(1), b) - self.assertEqual(dense.get_output_at(0), a_2) - self.assertEqual(dense.get_output_at(1), b_2) - self.assertEqual(dense.get_input_shape_at(0), (None, 32)) - self.assertEqual(dense.get_input_shape_at(1), (None, 32)) - self.assertEqual(dense.get_output_shape_at(0), (None, 16)) - self.assertEqual(dense.get_output_shape_at(1), (None, 16)) - - # Test invalid value for attribute retrieval. - with self.assertRaises(ValueError): - dense.get_input_at(2) - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - _ = new_dense.input - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - _ = new_dense.output - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - _ = new_dense.output_shape - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - _ = new_dense.input_shape - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - a = base_layers.Input(shape=(3, 32)) - a = base_layers.Input(shape=(5, 32)) - a_2 = dense(a) - b_2 = dense(b) - _ = new_dense.input_shape - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - a = base_layers.Input(shape=(3, 32)) - a = base_layers.Input(shape=(5, 32)) - a_2 = dense(a) - b_2 = dense(b) - _ = new_dense.output_shape - - def testTopologicalAttributesMultiOutputLayer(self): - - class PowersLayer(base_layers.Layer): - - def call(self, inputs): - return [inputs**2, inputs**3] - - x = base_layers.Input(shape=(32,)) - test_layer = PowersLayer() - p1, p2 = test_layer(x) # pylint: disable=not-callable - - self.assertEqual(test_layer.input, x) - self.assertEqual(test_layer.output, [p1, p2]) - self.assertEqual(test_layer.input_shape, (None, 32)) - self.assertEqual(test_layer.output_shape, [(None, 32), (None, 32)]) - - def testTopologicalAttributesMultiInputLayer(self): - - class AddLayer(base_layers.Layer): - - def call(self, inputs): - assert len(inputs) == 2 - return inputs[0] + inputs[1] - - a = base_layers.Input(shape=(32,)) - b = base_layers.Input(shape=(32,)) - test_layer = AddLayer() - y = test_layer([a, b]) # pylint: disable=not-callable - - self.assertEqual(test_layer.input, [a, b]) - self.assertEqual(test_layer.output, y) - self.assertEqual(test_layer.input_shape, [(None, 32), (None, 32)]) - self.assertEqual(test_layer.output_shape, (None, 32)) - @test_util.run_in_graph_and_eager_modes() def test_count_params(self): dense = core_layers.Dense(16) @@ -582,383 +470,5 @@ class BaseLayerTest(test.TestCase): self.assertEqual(len(layer.get_losses_for(x)), 1) -class NetworkTest(test.TestCase): - - def testBasicNetwork(self): - # minimum viable network - x = base_layers.Input(shape=(32,)) - dense = core_layers.Dense(2) - y = dense(x) - network = base_layers.Network(x, y, name='dense_network') - - # test basic attributes - self.assertEqual(network.name, 'dense_network') - self.assertEqual(len(network.layers), 2) # InputLayer + Dense - self.assertEqual(network.layers[1], dense) - self.assertEqual(network.weights, dense.weights) - self.assertEqual(network.trainable_weights, dense.trainable_weights) - self.assertEqual(network.non_trainable_weights, dense.non_trainable_weights) - - # test callability on Input - x_2 = base_layers.Input(shape=(32,)) - y_2 = network(x_2) - self.assertEqual(y_2.get_shape().as_list(), [None, 2]) - - # test callability on regular tensor - x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) - y_2 = network(x_2) - self.assertEqual(y_2.get_shape().as_list(), [None, 2]) - - # test network `trainable` attribute - network.trainable = False - self.assertEqual(network.weights, dense.weights) - self.assertEqual(network.trainable_weights, []) - self.assertEqual(network.non_trainable_weights, - dense.trainable_weights + dense.non_trainable_weights) - - def test_node_construction(self): - # test graph topology construction basics - a = base_layers.Input(shape=(32,), name='input_a') - b = base_layers.Input(shape=(32,), name='input_b') - - self.assertEqual(a.get_shape().as_list(), [None, 32]) - a_layer, a_node_index, a_tensor_index = a._keras_history - b_layer, _, _ = b._keras_history - self.assertEqual(len(a_layer._inbound_nodes), 1) - self.assertEqual(a_tensor_index, 0) - node = a_layer._inbound_nodes[a_node_index] - self.assertEqual(node.outbound_layer, a_layer) - - self.assertEqual(node.inbound_layers, []) - self.assertEqual(node.input_tensors, [a]) - self.assertEqual(node.input_shapes, [(None, 32)]) - self.assertEqual(node.output_tensors, [a]) - self.assertEqual(node.output_shapes, [(None, 32)]) - - dense = core_layers.Dense(16, name='dense_1') - dense(a) - dense(b) - - self.assertEqual(len(dense._inbound_nodes), 2) - self.assertEqual(len(dense._outbound_nodes), 0) - self.assertEqual(dense._inbound_nodes[0].inbound_layers, [a_layer]) - self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense) - self.assertEqual(dense._inbound_nodes[1].inbound_layers, [b_layer]) - self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense) - self.assertEqual(dense._inbound_nodes[0].input_tensors, [a]) - self.assertEqual(dense._inbound_nodes[1].input_tensors, [b]) - - # Test config - config_0 = dense._inbound_nodes[0].get_config() - self.assertEqual(config_0['outbound_layer'], dense.name) - - def testMultiInputNetwork(self): - a = base_layers.Input(shape=(32,), name='input_a') - b = base_layers.Input(shape=(32,), name='input_b') - - class AddLayer(base_layers.Layer): - - def call(self, inputs): - assert len(inputs) == 2 - return inputs[0] + inputs[1] - - c = AddLayer()([a, b]) # pylint: disable=not-callable - network = base_layers.Network([a, b], c) - self.assertEqual(len(network.layers), 3) # 2 * InputLayer + AddLayer - - # Test callability. - a2 = base_layers.Input(shape=(32,)) - b2 = base_layers.Input(shape=(32,)) - c2 = network([a2, b2]) - self.assertEqual(c2.get_shape().as_list(), [None, 32]) - - def testMultiOutputNetwork(self): - x = base_layers.Input(shape=(32,)) - y1 = core_layers.Dense(2)(x) - y2 = core_layers.Dense(3)(x) - network = base_layers.Network(x, [y1, y2]) - - self.assertEqual(len(network.layers), 3) # InputLayer + 2 * Dense - - # Test callability. - x2 = base_layers.Input(shape=(32,)) - outputs = network(x2) - - self.assertEqual(type(outputs), list) - self.assertEqual(len(outputs), 2) - self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) - self.assertEqual(outputs[1].get_shape().as_list(), [None, 3]) - - def testMultiInputMultiOutputNetworkSharedLayer(self): - a = base_layers.Input(shape=(32,), name='input_a') - b = base_layers.Input(shape=(32,), name='input_b') - - dense = core_layers.Dense(2) - - y1 = dense(a) - y2 = dense(b) - network = base_layers.Network([a, b], [y1, y2]) - self.assertEqual(len(network.layers), 3) # 2 * InputLayer + Dense - - # Test callability. - a2 = base_layers.Input(shape=(32,)) - b2 = base_layers.Input(shape=(32,)) - outputs = network([a2, b2]) - - self.assertEqual(type(outputs), list) - self.assertEqual(len(outputs), 2) - self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) - self.assertEqual(outputs[1].get_shape().as_list(), [None, 2]) - - def testCrossDataFlows(self): - # Test the ability to have multi-output layers with outputs that get routed - # to separate layers - - class PowersLayer(base_layers.Layer): - - def call(self, inputs): - return [inputs**2, inputs**3] - - x = base_layers.Input(shape=(32,)) - p1, p2 = PowersLayer()(x) # pylint: disable=not-callable - y1 = core_layers.Dense(2)(p1) - y2 = core_layers.Dense(3)(p2) - network = base_layers.Network(x, [y1, y2]) - - self.assertEqual(len(network.layers), 4) # InputLayer + 2 * Dense + PLayer - - # Test callability. - x2 = base_layers.Input(shape=(32,)) - outputs = network(x2) - - self.assertEqual(type(outputs), list) - self.assertEqual(len(outputs), 2) - self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) - self.assertEqual(outputs[1].get_shape().as_list(), [None, 3]) - - def testNetworkAttributes(self): - x = base_layers.Input(shape=(32,)) - z = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x**2))(x) - dense = core_layers.Dense(2, name='dense') - dense.add_update(1) - y = dense(z) - net = base_layers.Network(x, y) - - # losses - self.assertEqual(len(net.losses), 1) - - # updates - self.assertEqual(len(net.updates), 1) - - # get_layer - self.assertEqual(net.get_layer('dense'), dense) - self.assertEqual(net.get_layer(index=2), dense) - with self.assertRaises(ValueError): - net.get_layer('dense_unknown') - with self.assertRaises(ValueError): - net.get_layer() - with self.assertRaises(ValueError): - net.get_layer(index=4) - - # input, output - self.assertEqual(net.input, x) - self.assertEqual(net.output, y) - - # input_shape, output_shape - self.assertEqual(net.input_shape, (None, 32)) - self.assertEqual(net.output_shape, (None, 2)) - - # get_*_at - self.assertEqual(net.get_input_at(0), x) - self.assertEqual(net.get_output_at(0), y) - - # _compute_output_shape - self.assertEqual(net._compute_output_shape((3, 32)).as_list(), [3, 2]) - - def testInvalidNetworks(self): - # redundant inputs - x = base_layers.Input(shape=(32,)) - y = core_layers.Dense(2)(x) - with self.assertRaises(ValueError): - base_layers.Network([x, x], y) - - # inputs that don't come from Input - x = array_ops.placeholder(dtype='float32', shape=(None, 32)) - y = core_layers.Dense(2)(x) - with self.assertRaises(ValueError): - base_layers.Network(x, y) - - # inputs that don't come from Input but have a layer history - x = base_layers.Input(shape=(32,)) - x = core_layers.Dense(32)(x) - y = core_layers.Dense(2)(x) - with self.assertRaises(ValueError): - base_layers.Network(x, y) - - # outputs that don't come from layers - x = base_layers.Input(shape=(32,)) - y = core_layers.Dense(2)(x) - y = 2 * y - with self.assertRaises(ValueError): - base_layers.Network(x, y) - - # disconnected graphs - x1 = base_layers.Input(shape=(32,)) - x2 = base_layers.Input(shape=(32,)) - y = core_layers.Dense(2)(x1) - with self.assertRaises(ValueError): - base_layers.Network(x2, y) - - # redundant layer names - x = base_layers.Input(shape=(32,)) - z = core_layers.Dense(2, name='dense')(x) - y = core_layers.Dense(2, name='dense')(z) - with self.assertRaises(ValueError): - base_layers.Network(x, y) - - def testInputTensorWrapping(self): - x = array_ops.placeholder(dtype='float32', shape=(None, 32)) - x = base_layers.Input(tensor=x) - y = core_layers.Dense(2)(x) - base_layers.Network(x, y) - - def testExplicitBatchSize(self): - x = base_layers.Input(shape=(32,), batch_size=3) - y = core_layers.Dense(2)(x) - self.assertEqual(y.get_shape().as_list(), [3, 2]) - - def testNetworkRecursion(self): - # test the ability of networks to be used as layers inside networks. - a = base_layers.Input(shape=(32,)) - b = core_layers.Dense(2)(a) - net = base_layers.Network(a, b) - - c = base_layers.Input(shape=(32,)) - d = net(c) - - recursive_net = base_layers.Network(c, d) - self.assertEqual(len(recursive_net.layers), 2) - self.assertEqual(recursive_net.layers[1], net) - self.assertEqual(len(recursive_net.weights), 2) - - # test callability - x = array_ops.placeholder(dtype='float32', shape=(None, 32)) - y = recursive_net(x) - self.assertEqual(y.get_shape().as_list(), [None, 2]) - - def testSparseInput(self): - - class SparseSoftmax(base_layers.Layer): - - def call(self, inputs): - return sparse_ops.sparse_softmax(inputs) - - x = base_layers.Input(shape=(32,), sparse=True) - y = SparseSoftmax()(x) # pylint: disable=not-callable - network = base_layers.Network(x, y) - - self.assertEqual(len(network.layers), 2) - self.assertEqual(network.layers[0].sparse, True) - - @test_util.run_in_graph_and_eager_modes() - def testMaskingSingleInput(self): - - class MaskedLayer(base_layers.Layer): - - def call(self, inputs, mask=None): - if mask is not None: - return inputs * mask - return inputs - - def compute_mask(self, inputs, mask=None): - return array_ops.ones_like(inputs) - - if context.in_graph_mode(): - x = base_layers.Input(shape=(32,)) - y = MaskedLayer()(x) # pylint: disable=not-callable - network = base_layers.Network(x, y) - - # test callability on Input - x_2 = base_layers.Input(shape=(32,)) - y_2 = network(x_2) - self.assertEqual(y_2.get_shape().as_list(), [None, 32]) - - # test callability on regular tensor - x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) - y_2 = network(x_2) - self.assertEqual(y_2.get_shape().as_list(), [None, 32]) - else: - a = constant_op.constant([2] * 32) - mask = constant_op.constant([0, 1] * 16) - a._keras_mask = mask - b = MaskedLayer().apply(a) - self.assertTrue(hasattr(b, '_keras_mask')) - self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)), - self.evaluate(getattr(b, '_keras_mask'))) - self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b)) - - -class DeferredModeTest(test.TestCase): - - def testDeferredTensorAttributes(self): - x = base_layers._DeferredTensor(shape=(None, 2), dtype='float32', name='x') - self.assertEqual(str(x), - 'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)') - self.assertEqual(repr(x), - '<_DeferredTensor \'x\' shape=(?, 2) dtype=float32>') - - @test_util.run_in_graph_and_eager_modes() - def testSimpleNetworkBuilding(self): - inputs = base_layers.Input(shape=(32,)) - if context.in_eager_mode(): - self.assertIsInstance(inputs, base_layers._DeferredTensor) - self.assertEqual(inputs.dtype.name, 'float32') - self.assertEqual(inputs.shape.as_list(), [None, 32]) - - x = core_layers.Dense(2)(inputs) - if context.in_eager_mode(): - self.assertIsInstance(x, base_layers._DeferredTensor) - self.assertEqual(x.dtype.name, 'float32') - self.assertEqual(x.shape.as_list(), [None, 2]) - - outputs = core_layers.Dense(4)(x) - network = base_layers.Network(inputs, outputs) - self.assertIsInstance(network, base_layers.Network) - - if context.in_eager_mode(): - # It should be possible to call such a network on EagerTensors. - inputs = constant_op.constant( - np.random.random((10, 32)).astype('float32')) - outputs = network(inputs) - self.assertEqual(outputs.shape.as_list(), [10, 4]) - - @test_util.run_in_graph_and_eager_modes() - def testMultiIONetworkbuilding(self): - input_a = base_layers.Input(shape=(32,)) - input_b = base_layers.Input(shape=(16,)) - a = core_layers.Dense(16)(input_a) - - class AddLayer(base_layers.Layer): - - def call(self, inputs): - return inputs[0] + inputs[1] - - def _compute_output_shape(self, input_shape): - return input_shape[0] - - c = AddLayer()([a, input_b]) # pylint: disable=not-callable - c = core_layers.Dense(2)(c) - - network = base_layers.Network([input_a, input_b], [a, c]) - if context.in_eager_mode(): - a_val = constant_op.constant( - np.random.random((10, 32)).astype('float32')) - b_val = constant_op.constant( - np.random.random((10, 16)).astype('float32')) - outputs = network([a_val, b_val]) - self.assertEqual(len(outputs), 2) - self.assertEqual(outputs[0].shape.as_list(), [10, 16]) - self.assertEqual(outputs[1].shape.as_list(), [10, 2]) - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index 76e8fbef2f4..7be1fa5cfe9 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -286,11 +286,19 @@ class Dropout(base.Layer): self.noise_shape = noise_shape self.seed = seed - def _get_noise_shape(self, _): + def _get_noise_shape(self, inputs): # Subclasses of `Dropout` may implement `_get_noise_shape(self, inputs)`, # which will override `self.noise_shape`, and allows for custom noise # shapes with dynamically sized inputs. - return self.noise_shape + if self.noise_shape is None: + return self.noise_shape + + symbolic_shape = array_ops.shape(inputs) + noise_shape = [ + symbolic_shape[axis] if shape is None else shape + for axis, shape in enumerate(self.noise_shape) + ] + return noise_shape def call(self, inputs, training=False): diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py index b67df89f81f..2d47cc69798 100644 --- a/tensorflow/python/layers/core_test.py +++ b/tensorflow/python/layers/core_test.py @@ -387,6 +387,16 @@ class DropoutTest(test.TestCase): self.assertAllClose(np.ones((5, 5)), np_output) @test_util.run_in_graph_and_eager_modes() + def testDynamicNoiseShape(self): + inputs = array_ops.ones((5, 3, 2)) + noise_shape = [None, 1, None] + dp = core_layers.Dropout(0.5, noise_shape=noise_shape, seed=1) + dropped = dp.apply(inputs, training=True) + self.evaluate(variables.global_variables_initializer()) + np_output = self.evaluate(dropped) + self.assertAlmostEqual(0., np_output.min()) + self.assertAllClose(np_output[:, 0, :], np_output[:, 1, :]) + def testCustomNoiseShape(self): inputs = array_ops.ones((5, 3, 2)) noise_shape = [5, 1, 2] diff --git a/tensorflow/python/layers/layers.py b/tensorflow/python/layers/layers.py index d3f532e79c1..0a52b1e8d92 100644 --- a/tensorflow/python/layers/layers.py +++ b/tensorflow/python/layers/layers.py @@ -65,8 +65,8 @@ from tensorflow.python.util.all_util import remove_undocumented # Base objects. from tensorflow.python.layers.base import Layer -from tensorflow.python.layers.base import Input from tensorflow.python.layers.base import InputSpec +from tensorflow.python.layers.network import Input # Core layers. from tensorflow.python.layers.core import Dense diff --git a/tensorflow/python/layers/network.py b/tensorflow/python/layers/network.py new file mode 100644 index 00000000000..9a33a5c7269 --- /dev/null +++ b/tensorflow/python/layers/network.py @@ -0,0 +1,957 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Contains Network, a composition of layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.eager import context +from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base +from tensorflow.python.layers import utils as layers_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest + + +class InputLayer(base.Layer): + """Layer to be used as an entry point into a Network (a graph of layers). + + It can either wrap an existing tensor (pass an `input_tensor` argument) + or create its a placeholder tensor (pass arguments `input_shape` + as well as `dtype`). + + It is generally recommend to use the functional layer API via `Input`, + (which creates an `InputLayer`) without directly using `InputLayer`. + + Arguments: + input_shape: Shape tuple (not including the batch axis), or `TensorShape` + instance (not including the batch axis). + batch_size: Optional input batch size (integer or None). + dtype: Datatype of the input. + input_tensor: Optional tensor to use as layer input + instead of creating a placeholder. + sparse: Boolean, whether the placeholder created + is meant to be sparse. + name: Name of the layer (string). + + Raises: + RuntimeError: If created in Eager mode. + """ + + def __init__(self, + input_shape=None, + batch_size=None, + dtype=dtypes.float32, + input_tensor=None, + sparse=False, + name=None): + super(InputLayer, self).__init__(dtype=dtype, name=name) + self.built = True + self.sparse = sparse + self.batch_size = batch_size + + if isinstance(input_shape, tensor_shape.TensorShape): + input_shape = tuple(input_shape.as_list()) + + if input_tensor is None: + if input_shape is not None: + batch_input_shape = (batch_size,) + tuple(input_shape) + else: + batch_input_shape = None + + if context.in_eager_mode(): + # In eager mode, create a temporary placeholder to call the layer on. + input_tensor = base._DeferredTensor( # pylint: disable=protected-access + shape=batch_input_shape, + dtype=dtype, + name=self.name) + else: + # In graph mode, create a graph placeholder to call the layer on. + if sparse: + input_tensor = array_ops.sparse_placeholder( + shape=batch_input_shape, + dtype=dtype, + name=self.name) + else: + input_tensor = array_ops.placeholder( + shape=batch_input_shape, + dtype=dtype, + name=self.name) + + # For compatibility with Keras API. + self.is_placeholder = True + self._batch_input_shape = batch_input_shape + else: + # For compatibility with Keras API. + self.is_placeholder = False + self._batch_input_shape = tuple(input_tensor.get_shape().as_list()) + + # Create an input node to add to self.outbound_node + # and set output_tensors' _keras_history. + input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access + base.Node( + self, + inbound_layers=[], + node_indices=[], + tensor_indices=[], + input_tensors=[input_tensor], + output_tensors=[input_tensor]) + + +def Input( # pylint: disable=invalid-name + shape=None, + batch_size=None, + name=None, + dtype=dtypes.float32, + sparse=False, + tensor=None): + """`Input()` is used to instantiate an input tensor for use with a `Network`. + + For instance, if a, b and c are tensors created via `Input`, + it becomes possible to do: + + `network = Network(inputs=[a, b], outputs=c)` + + Example: + + ```python + # This is a logistic regression + x = tf.layers.Input(shape=(32,)) + y = tf.layers.Dense(16, activation='softmax')(x) + network = tf.layers.Network(x, y) + ``` + + Arguments: + shape: A shape tuple (integer), not including the batch size. + For instance, `shape=(32,)` indicates that the expected input + will be batches of 32-dimensional vectors. + batch_size: Optional input batch size (integer or None). + name: An optional name string for the layer. + Should be unique in a model (do not reuse the same name twice). + It will be autogenerated if it isn't provided. + dtype: The data type expected by the input, as a string + (`float32`, `float64`, `int32`...) + sparse: A boolean specifying whether the placeholder + to be created is sparse. + tensor: Optional existing tensor to wrap into the `Input` layer. + If set, the layer will not create a placeholder tensor. + + Returns: + A tensor: either a new placeholder (with history metadata) or + `tensor` (if passed), with added history metadata. + + Raises: + RuntimeError: If called in Eager mode. + """ + input_layer = InputLayer( + input_shape=shape, + batch_size=batch_size, + name=name, + dtype=dtype, + sparse=sparse, + input_tensor=tensor) + # Return tensor including `_keras_history` metadata. + # Note that in this case train_output and test_output are the same pointer. + outputs = input_layer._inbound_nodes[0].output_tensors # pylint: disable=protected-access + if len(outputs) == 1: + return outputs[0] + else: + return outputs + + +class GraphNetwork(base.Layer): + """A GraphNetwork is a directed acyclic graph of layers. + + It is the topological form of a "model". + A Model is simply a GraphNetwork with added training/evaluation routines. + + A GraphNetwork instance implements the full Layer API. In particular, a + GraphNetwork can be called on new inputs. + + Example: + + ```python + # This is a logistic regression + x = tf.layers.Input(shape=(32,)) + y = tf.layers.Dense(16, activation='softmax')(x) + network = tf.layers.GraphNetwork(x, y) + + # It is then possible to call the network on compatible inputs: + z = tf.layers.Input(shape=(32,)) + w = network(z) + + # It is possible to retrieve the same properties as a layer: + weights = network.trainable_weights + ``` + + Arguments: + inputs: Input tensor or list of input tensors. + Must come from `tf.layers.Input`. + output: Output tensor or list of output tensors. Must come from + tf.layers Layers or Keras layers. + name: Optional name of the model (string). + + Attributes: + GraphNetwork has the same attributes as Layer. On top of it, it also has: + - layers: a list of the children layers of the network, + a list of layer instances, ordered from "earlier in the graph" + to "later in the graph". + + Methods: + GraphNetwork has the same methods as Layer. On top of it, it also has: + - get_layer: retrieves a child layer by name or index in the graph. + + Raises: + RuntimeError: If created in Eager mode. + """ + + def __init__(self, inputs, outputs, name=None): # pylint: disable=super-init-not-called + if context.in_eager_mode(): + # TODO(fchollet): check that all inputs and outputs are DeferredTensors. + pass + + self._init_set_name(name) + self._activity_regularizer = None + with vs.variable_scope( + None, default_name=self._base_name) as captured_scope: + self._scope = captured_scope + call_fn_args = estimator_util.fn_args(self.call) + self._compute_previous_mask = ('mask' in call_fn_args or + hasattr(self, 'compute_mask')) + self._call_has_scope_arg = 'scope' in call_fn_args + + # This acts just like the `trainable` attribute of any layer instance. + # It does not affect users of the underlying layers, only users of the + # GraphNetwork instance. + self.trainable = True + # A GraphNetwork does not create weights of its own, thus it is already + # built. + self.built = True + # A GraphNetwork does not create weights of its own, thus has no dtype. + self._dtype = None + # The following are implemented as property functions: + # self.trainable_weights + # self.non_trainable_weights + # self.input_spec + + # Private attributes to implement compatibility with Layer. + self._per_input_losses = {} + self._per_input_updates = {} + self._updates = [] + self._losses = [] + self._scope = None + self._reuse = None + self._graph = ops.get_default_graph() + + # GraphNetwork-specific properties. + if isinstance(inputs, (list, tuple)): + self.inputs = list(inputs) # Tensor or list of tensors. + else: + self.inputs = [inputs] + if isinstance(outputs, (list, tuple)): + self.outputs = list(outputs) + else: + self.outputs = [outputs] + # All layers in order of horizontal graph traversal. + # Entries are unique. Includes input and output layers. + self.layers = [] + + # Check for redundancy in inputs. + if len(set(self.inputs)) != len(self.inputs): + raise ValueError('The list of inputs passed to the model ' + 'is redundant. ' + 'All inputs should only appear once.' + ' Found: ' + str(self.inputs)) + + # # List of initial layers (1 to 1 mapping with self.inputs, + # # hence the same layer might appear twice) + # self._input_layers = [] + # self._input_layers_node_indices = [] + # self._input_layers_tensor_indices = [] + # # list of layers (1 to 1 mapping with self.inputs, + # # hence the same layer might appear twice) + # self._output_layers = [] + # self._output_layers_node_indices = [] + # self._output_layers_tensor_indices = [] + + self._input_layers = [] + self._output_layers = [] + self._input_coordinates = [] + self._output_coordinates = [] + + # This is for performance optimization when calling the GraphNetwork on new + # inputs. Every time the GraphNetwork is called on a set on input tensors, + # we compute the output tensors, output masks and output shapes in one pass, + # then cache them here. When any of these outputs is queried later, we + # retrieve it from there instead of recomputing it. + self._output_mask_cache = {} + self._output_tensor_cache = {} + self._output_shape_cache = {} + + # User-provided arguments validation. + for x in self.inputs: + # Check that x has appropriate `_keras_history` metadata. + if not hasattr(x, '_keras_history'): + cls_name = self.__class__.__name__ + raise ValueError('Input tensors to a ' + cls_name + ' ' + + 'must come from `tf.layers.Input`. ' + 'Received: ' + str(x) + + ' (missing previous layer metadata).') + # Check that x is an input tensor. + # pylint: disable=protected-access + layer, node_index, tensor_index = x._keras_history + if len(layer._inbound_nodes) > 1 or ( + layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers): + cls_name = self.__class__.__name__ + logging.warning(cls_name + ' inputs must come from ' + '`tf.layers.Input` (thus holding past layer metadata), ' + 'they cannot be the output of ' + 'a previous non-Input layer. ' + 'Here, a tensor specified as ' + 'input to "' + self.name + '" was not an Input tensor, ' + 'it was generated by layer ' + layer.name + '.\n' + 'Note that input tensors are ' + 'instantiated via `tensor = tf.layers.Input(shape)`.\n' + 'The tensor that caused the issue was: ' + str(x.name)) + # pylint: enable=protected-access + for x in self.outputs: + if not hasattr(x, '_keras_history'): + cls_name = self.__class__.__name__ + raise ValueError('Output tensors to a ' + cls_name + ' must be ' + 'the output of a TensorFlow `Layer` ' + '(thus holding past layer metadata). Found: ' + str(x)) + + # Build self._output_layers: + for x in self.outputs: + layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access + self._output_layers.append(layer) + self._output_coordinates.append((layer, node_index, tensor_index)) + + # Build self._input_layers: + for x in self.inputs: + layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access + # It's supposed to be an input layer, so only one node + # and one tensor output. + assert node_index == 0 + assert tensor_index == 0 + self._input_layers.append(layer) + self._input_coordinates.append((layer, node_index, tensor_index)) + + # Network_nodes: set of nodes included in the graph + # (not all nodes included in the layers + # are relevant to the current graph). + network_nodes = set() # ids of all nodes relevant to the GraphNetwork + nodes_depths = {} # dict {node: depth value} + layers_depths = {} # dict {layer: depth value} + layer_indices = {} # dict {layer: index in traversal} + nodes_in_decreasing_depth = [] + + def build_map_of_graph(tensor, + finished_nodes, + nodes_in_progress, + layer, + node_index, + tensor_index): + """Builds a map of the graph of layers. + + This recursively updates the map `layer_indices`, + the list `nodes_in_decreasing_depth` and the set `network_nodes`. + + Arguments: + tensor: Some tensor in a graph. + finished_nodes: Set of nodes whose subgraphs have been traversed + completely. Useful to prevent duplicated work. + nodes_in_progress: Set of nodes that are currently active on the + recursion stack. Useful to detect cycles. + layer: Layer from which `tensor` comes from. If not provided, + will be obtained from `tensor._keras_history`. + node_index: Node index from which `tensor` comes from. + tensor_index: Tensor_index from which `tensor` comes from. + + Raises: + ValueError: if a cycle is detected. + """ + node = layer._inbound_nodes[node_index] # pylint: disable=protected-access + + # Prevent cycles. + if node in nodes_in_progress: + raise ValueError('The tensor ' + str(tensor) + ' at layer "' + + layer.name + '" is part of a cycle.') + + # Don't repeat work for shared subgraphs + if node in finished_nodes: + return + + node_key = _make_node_key(layer.name, node_index) + # Update network_nodes. + network_nodes.add(node_key) + + # Store the traversal order for layer sorting. + if layer not in layer_indices: + layer_indices[layer] = len(layer_indices) + + nodes_in_progress.add(node) + + # Propagate to all previous tensors connected to this node. + for i in range(len(node.inbound_layers)): + x = node.input_tensors[i] + layer = node.inbound_layers[i] + node_index = node.node_indices[i] + tensor_index = node.tensor_indices[i] + build_map_of_graph(x, finished_nodes, nodes_in_progress, layer, + node_index, tensor_index) + + finished_nodes.add(node) + nodes_in_progress.remove(node) + nodes_in_decreasing_depth.append(node) + + finished_nodes = set() + nodes_in_progress = set() + for x in self.outputs: + layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access + build_map_of_graph(x, finished_nodes, nodes_in_progress, + layer=layer, + node_index=node_index, + tensor_index=tensor_index) + + for node in reversed(nodes_in_decreasing_depth): + # If the depth is not set, the node has no outbound nodes (depth 0). + depth = nodes_depths.setdefault(node, 0) + + # Update the depth of the corresponding layer + previous_depth = layers_depths.get(node.outbound_layer, 0) + # If we've seen this layer before at a higher depth, + # we should use that depth instead of the node depth. + # This is necessary for shared layers that have inputs at different + # depth levels in the graph. + depth = max(depth, previous_depth) + layers_depths[node.outbound_layer] = depth + nodes_depths[node] = depth + + # Update the depth of inbound nodes. + # The "depth" of a node is the max of the depths + # of all layers it is connected to. + for i in range(len(node.inbound_layers)): + inbound_layer = node.inbound_layers[i] + node_index = node.node_indices[i] + inbound_node = inbound_layer._inbound_nodes[node_index] # pylint: disable=protected-access + previous_depth = nodes_depths.get(inbound_node, 0) + nodes_depths[inbound_node] = max(depth + 1, previous_depth) + + # Build a dict {depth: list of nodes with this depth} + nodes_by_depth = {} + for node, depth in nodes_depths.items(): + if depth not in nodes_by_depth: + nodes_by_depth[depth] = [] + nodes_by_depth[depth].append(node) + + # Build a dict {depth: list of layers with this depth} + layers_by_depth = {} + for layer, depth in layers_depths.items(): + if depth not in layers_by_depth: + layers_by_depth[depth] = [] + layers_by_depth[depth].append(layer) + + # Get sorted list of layer depths. + depth_keys = list(layers_by_depth.keys()) + depth_keys.sort(reverse=True) + + # Set self.layers and self._layers_by_depth. + layers = [] + for depth in depth_keys: + layers_for_depth = layers_by_depth[depth] + # GraphNetwork.layers needs to have a deterministic order: + # here we order them by traversal order. + layers_for_depth.sort(key=lambda x: layer_indices[x]) + layers.extend(layers_for_depth) + self.layers = layers + self._layers_by_depth = layers_by_depth + + # Get sorted list of node depths. + depth_keys = list(nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + + # Check that all tensors required are computable. + # computable_tensors: all tensors in the graph + # that can be computed from the inputs provided. + computable_tensors = [] + for x in self.inputs: + computable_tensors.append(x) + + layers_with_complete_input = [] # To provide a better error msg. + for depth in depth_keys: + for node in nodes_by_depth[depth]: + layer = node.outbound_layer + if layer: + for x in node.input_tensors: + if x not in computable_tensors: + raise ValueError('Graph disconnected: ' + 'cannot obtain value for tensor ' + str(x) + + ' at layer "' + layer.name + '". ' + 'The following previous layers ' + 'were accessed without issue: ' + + str(layers_with_complete_input)) + for x in node.output_tensors: + computable_tensors.append(x) + layers_with_complete_input.append(layer.name) + + # Keep track of the network's nodes. + self._network_nodes = network_nodes + self._nodes_by_depth = nodes_by_depth + + # Ensure name unicity, which will be crucial for serialization + # (since serialized nodes refer to layers by their name). + all_names = [layer.name for layer in self.layers] + for name in all_names: + if all_names.count(name) != 1: + raise ValueError('The name "' + name + '" is used ' + + str(all_names.count(name)) + ' times in the model. ' + 'All layer names should be unique.') + + # Layer parameters. + # The new network starts with a single inbound node + # for its inputs, and no outbound nodes. + self._outbound_nodes = [] # Will be appended to by future calls to __call__ + self._inbound_nodes = [ + ] # Will be appended to below, and by future calls to __call__ + # Create the node linking internal inputs to internal outputs. + base.Node( + outbound_layer=self, + inbound_layers=[], + node_indices=[], + tensor_indices=[], + input_tensors=self.inputs, + output_tensors=self.outputs) + + def get_layer(self, name=None, index=None): + """Retrieves a layer based on either its name (unique) or index. + + Indices are based on order of horizontal graph traversal (bottom-up). + + Arguments: + name: String, name of layer. + index: Integer, index of layer. + + Returns: + A layer instance. + + Raises: + ValueError: In case of invalid layer name or index. + """ + # TODO(fchollet): We could build a dictionary based on layer names + # since they are constant, but we have not done that yet. + if index is not None: + if len(self.layers) <= index: + raise ValueError('Was asked to retrieve layer at index ' + str(index) + + ' but model only has ' + str(len(self.layers)) + + ' layers.') + else: + return self.layers[index] + else: + if not name: + raise ValueError('Provide either a layer name or layer index.') + for layer in self.layers: + if layer.name == name: + return layer + raise ValueError('No such layer: ' + name) + + @property + def updates(self): + """Retrieve the network's updates. + + Will only include updates that are either + unconditional, or conditional on inputs to this model + (e.g. will not include updates that depend on tensors + that aren't inputs to this model). + + Returns: + A list of update ops. + """ + updates = [] + for layer in self.layers: + if hasattr(layer, 'updates'): + # Collect updates that are dependent on inputs + # that are part of the model. + for node_index, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access + node_key = _make_node_key(layer.name, node_index) + if node_key in self._network_nodes: + # The model owns this layer node. + inputs = node.input_tensors + updates += layer.get_updates_for(inputs) + # Collect unconditional updates. + updates += layer.get_updates_for(None) + return updates + + @property + def losses(self): + """Retrieve the network's losses. + + Will only include losses that are either + unconditional, or conditional on inputs to this model + (e.g. will not include losses that depend on tensors + that aren't inputs to this model). + + Returns: + A list of loss tensors. + """ + losses = [] + # Retrieve losses for all internal layers. + for layer in self.layers: + if hasattr(layer, 'losses'): + # Collect losses that are dependent on inputs + # that are part of the model. + for node_index, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access + node_key = _make_node_key(layer.name, node_index) + if node_key in self._network_nodes: + # The model owns this layer node. + inputs = node.input_tensors + losses += layer.get_losses_for(inputs) + # Collect unconditional losses. + losses += layer.get_losses_for(None) + # Add any potential unconditional model-level loss. + losses += self.get_losses_for(None) + return losses + + @property + def trainable_weights(self): + if not self.trainable: + return [] + weights = [] + for layer in self.layers: + weights += layer.trainable_weights + return weights + + @property + def non_trainable_weights(self): + weights = [] + for layer in self.layers: + weights += layer.non_trainable_weights + if not self.trainable: + trainable_weights = [] + for layer in self.layers: + trainable_weights += layer.trainable_weights + return trainable_weights + weights + return weights + + @property + def input_spec(self): + """Gets the network's input specs. + + Returns: + A list of `InputSpec` instances (one per input to the model) + or a single instance if the model has only one input. + """ + specs = [] + for layer in self._input_layers: + if layer.input_spec is None: + specs.append(None) + else: + if not isinstance(layer.input_spec, list): + raise TypeError('Layer ' + layer.name + + ' has an input_spec attribute that ' + 'is not a list. We expect a list. ' + 'Found input_spec = ' + str(layer.input_spec)) + specs += layer.input_spec + if len(specs) == 1: + return specs[0] + return specs + + def call(self, inputs, mask=None): + """Call the model on new inputs. + + In this case `call` just reapplies + all ops in the graph to the new inputs + (e.g. build a new computational graph from the provided inputs). + + Arguments: + inputs: A tensor or list of tensors. + mask: A mask or list of masks. A mask can be + either a tensor or None (no mask). + + Returns: + A tensor if there is a single output, or + a list of tensors if there are more than one outputs. + """ + inputs = nest.flatten(inputs) + if mask is None: + masks = [None for _ in range(len(inputs))] + else: + masks = nest.flatten(mask) + + if context.in_graph_mode(): + # Try to retrieve cached outputs if the layer has already been called + # on these exact inputs. + cache_key = (layers_util.object_list_uid(inputs) + + '_' + layers_util.object_list_uid(masks)) + if cache_key in self._output_tensor_cache: + # Cache hit. + return self._output_tensor_cache[cache_key] + # Actually apply the network graph to the new inputs. + outputs, _ = self._run_internal_graph(inputs, masks) + return outputs + + def _compute_output_shape(self, input_shape): + if isinstance(input_shape, list): + input_shapes = [] + for shape in input_shape: + if shape is not None: + input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list())) + else: + input_shapes.append(None) + else: + if input_shape is not None: + input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())] + else: + input_shapes = [None] + + if len(input_shapes) != len(self._input_layers): + raise ValueError('Invalid input_shape argument ' + str(input_shape) + + ': model has ' + str(len(self._input_layers)) + + ' tensor inputs.') + + cache_key = layers_util.object_list_uid(input_shapes) + if cache_key not in self._output_shape_cache: + # Cache miss. We have to run the network graph manually (recursive calls + # to `_compute_output_shape`). + layers_to_output_shapes = {} + for i in range(len(input_shapes)): + layer = self._input_layers[i] + input_shape = input_shapes[i] + # It's an input layer: then `_compute_output_shape` is identity, + # and there is only one node and one tensor output. + shape_key = layer.name + '_0_0' + layers_to_output_shapes[shape_key] = input_shape + + depth_keys = list(self._nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + # Iterate over nodes, by depth level. + if len(depth_keys) > 1: + for depth in depth_keys: + nodes = self._nodes_by_depth[depth] + for node in nodes: + # This is always a single layer, never a list. + layer = node.outbound_layer + if layer in self._input_layers: + # We've already covered the input layers + # a few lines above. + continue + # Potentially redundant list, + # same size as node.input_tensors. + input_shapes = [] + for j in range(len(node.inbound_layers)): + inbound_layer = node.inbound_layers[j] + node_index = node.node_indices[j] + tensor_index = node.tensor_indices[j] + shape_key = inbound_layer.name + '_%s_%s' % (node_index, + tensor_index) + input_shape = layers_to_output_shapes[shape_key] + input_shapes.append(input_shape) + + if len(input_shapes) == 1: + output_shape = layer._compute_output_shape(input_shapes[0]) # pylint: disable=protected-access + else: + output_shape = layer._compute_output_shape(input_shapes) # pylint: disable=protected-access + if isinstance(output_shape, list): + output_shapes = [ + tuple(tensor_shape.TensorShape(shape).as_list()) + for shape in output_shape + ] + else: + output_shapes = [ + tuple(tensor_shape.TensorShape(output_shape).as_list()) + ] + + node_index = layer._inbound_nodes.index(node) # pylint: disable=protected-access + for j in range(len(output_shapes)): + shape_key = layer.name + '_%s_%s' % (node_index, j) + layers_to_output_shapes[shape_key] = output_shapes[j] + + # Read final output shapes from layers_to_output_shapes. + output_shapes = [] + for i in range(len(self._output_layers)): + layer, node_index, tensor_index = self._output_coordinates[i] + shape_key = layer.name + '_%s_%s' % (node_index, tensor_index) + output_shapes.append(layers_to_output_shapes[shape_key]) + + # Store in cache. + self._output_shape_cache[cache_key] = output_shapes + else: + # Cache hit. + output_shapes = self._output_shape_cache[cache_key] + + if isinstance(output_shapes, list): + if len(output_shapes) == 1: + return tensor_shape.TensorShape(output_shapes[0]) + else: + return [tensor_shape.TensorShape(shape) for shape in output_shapes] + else: + return tensor_shape.TensorShape(output_shapes) + + def _run_internal_graph(self, inputs, masks=None): + """Computes output tensors for new inputs. + + # Note: + - Expects `inputs` to be a list (potentially with 1 element). + - Can be run on non-Keras tensors. + + Arguments: + inputs: List of tensors + masks: List of masks (tensors or None). + + Returns: + Three lists: output_tensors, output_masks, output_shapes + """ + # Note: masking support is relevant mainly for Keras. + # It cannot be factored out without having the fully reimplement the network + # calling logic on the Keras side. We choose to incorporate it in + # GraphNetwork because 1) it may be useful to fully support in tf.layers in + # the future and 2) Keras is a major user of GraphNetwork. If you don't + # use masking, it does not interfere with regular behavior at all and you + # can ignore it. + if masks is None: + masks = [None for _ in range(len(inputs))] + + # Dictionary mapping reference tensors to tuples + # (computed tensor, compute mask) + # we assume a 1:1 mapping from tensor to mask + # TODO(fchollet): raise exception when a `.compute_mask()` call + # does not return a list the same size as `call` + tensor_map = {} + for x, y, mask in zip(self.inputs, inputs, masks): + tensor_map[str(id(x))] = (y, mask) + + depth_keys = list(self._nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + for depth in depth_keys: + nodes = self._nodes_by_depth[depth] + for node in nodes: + # This is always a single layer, never a list. + layer = node.outbound_layer + + reference_input_tensors = node.input_tensors + reference_output_tensors = node.output_tensors + + # If all previous input tensors are available in tensor_map, + # then call node.inbound_layer on them. + computed_data = [] # List of tuples (input, mask). + for x in reference_input_tensors: + if str(id(x)) in tensor_map: + computed_data.append(tensor_map[str(id(x))]) + + if len(computed_data) == len(reference_input_tensors): + # Call layer (reapplying ops to new inputs). + with ops.name_scope(layer.name): + if node.arguments: + kwargs = node.arguments + else: + kwargs = {} + if len(computed_data) == 1: + computed_tensor, computed_mask = computed_data[0] + # Ensure mask propagation if applicable. + if 'mask' in estimator_util.fn_args(layer.call): + if 'mask' not in kwargs: + kwargs['mask'] = computed_mask + + output_tensors = nest.flatten( + layer.call(computed_tensor, **kwargs)) + if hasattr(layer, 'compute_mask'): + output_masks = nest.flatten( + layer.compute_mask(computed_tensor, computed_mask)) + else: + output_masks = [None for _ in range(len(output_tensors))] + computed_tensors = [computed_tensor] + computed_masks = [computed_mask] + else: + computed_tensors = [x[0] for x in computed_data] + computed_masks = [x[1] for x in computed_data] + if 'mask' in estimator_util.fn_args(layer.call): + if 'mask' not in kwargs: + kwargs['mask'] = computed_masks + output_tensors = nest.flatten( + layer.call(computed_tensors, **kwargs)) + if hasattr(layer, 'compute_mask'): + output_masks = nest.flatten( + layer.compute_mask(computed_tensors, computed_masks)) + else: + output_masks = [None for _ in range(len(output_tensors))] + + # Apply activity regularizer if any: + if layer.activity_regularizer is not None: + regularization_losses = [ + layer.activity_regularizer(x) for x in computed_tensors + ] + layer.add_loss(regularization_losses, computed_tensors) + + if context.in_graph_mode(): + # Update model updates and losses: + # Keep track of updates that depend on the inputs + # (e.g. BN updates). + self.add_update(layer.get_updates_for(computed_tensors), inputs) + # Keep track of unconditional updates (e.g. a counter). + self.add_update(layer.get_updates_for(None), None) + # Keep track of losses that depend on the inputs + # (e.g. activity regularizers). + self.add_loss(layer.get_losses_for(computed_tensors), inputs) + # Keep track of unconditional losses + # (e.g. weight regularizers). + self.add_loss(layer.get_losses_for(None), None) + + # Update tensor_map. + for x, y, mask in zip(reference_output_tensors, output_tensors, + output_masks): + tensor_map[str(id(x))] = (y, mask) + + output_tensors = [] + output_masks = [] + output_shapes = [] + for x in self.outputs: + assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x) + tensor, mask = tensor_map[str(id(x))] + output_shapes.append(layers_util.static_shape(x)) + output_tensors.append(tensor) + output_masks.append(mask) + + if len(output_tensors) == 1: + output_tensors = output_tensors[0] + if output_shapes is not None: + output_shapes = output_shapes[0] + if output_masks is not None: + output_masks = output_masks[0] + + if context.in_graph_mode(): + # Update cache; + # keys are based on ids on input tensors and inputs masks. + cache_key = (layers_util.object_list_uid(inputs) + + '_' + layers_util.object_list_uid(masks)) + self._output_tensor_cache[cache_key] = output_tensors + if output_masks is not None: + self._output_mask_cache[cache_key] = output_masks + if output_shapes is not None: + input_shapes = [layers_util.static_shape(x) for x in inputs] + cache_key = layers_util.object_list_uid(input_shapes) + self._output_shape_cache[cache_key] = output_shapes + + return output_tensors, output_masks + + +def _make_node_key(layer_name, node_index): + return layer_name + '_ib-' + str(node_index) diff --git a/tensorflow/python/layers/network_test.py b/tensorflow/python/layers/network_test.py new file mode 100644 index 00000000000..af7813e2642 --- /dev/null +++ b/tensorflow/python/layers/network_test.py @@ -0,0 +1,525 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.layers.network.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.layers import base as base_layers +from tensorflow.python.layers import core as core_layers +from tensorflow.python.layers import network as network_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class BaseLayerCompatibilityTest(test.TestCase): + + def test_get_updates_for(self): + a = network_layers.Input(shape=(2,)) + dense_layer = core_layers.Dense(1) + dense_layer.add_update(0, inputs=a) + dense_layer.add_update(1, inputs=None) + + self.assertEqual(dense_layer.get_updates_for(a), [0]) + self.assertEqual(dense_layer.get_updates_for(None), [1]) + + def test_get_losses_for(self): + a = network_layers.Input(shape=(2,)) + dense_layer = core_layers.Dense(1) + dense_layer.add_loss(0, inputs=a) + dense_layer.add_loss(1, inputs=None) + + self.assertEqual(dense_layer.get_losses_for(a), [0]) + self.assertEqual(dense_layer.get_losses_for(None), [1]) + + def testTopologicalAttributes(self): + # test layer attributes / methods related to cross-layer connectivity. + a = network_layers.Input(shape=(32,), name='input_a') + b = network_layers.Input(shape=(32,), name='input_b') + + # test input, output, input_shape, output_shape + test_layer = core_layers.Dense(16, name='test_layer') + a_test = test_layer(a) + self.assertEqual(test_layer.input, a) + self.assertEqual(test_layer.output, a_test) + self.assertEqual(test_layer.input_shape, (None, 32)) + self.assertEqual(test_layer.output_shape, (None, 16)) + + # test `get_*_at` methods + dense = core_layers.Dense(16, name='dense_1') + a_2 = dense(a) + b_2 = dense(b) + + self.assertEqual(dense.get_input_at(0), a) + self.assertEqual(dense.get_input_at(1), b) + self.assertEqual(dense.get_output_at(0), a_2) + self.assertEqual(dense.get_output_at(1), b_2) + self.assertEqual(dense.get_input_shape_at(0), (None, 32)) + self.assertEqual(dense.get_input_shape_at(1), (None, 32)) + self.assertEqual(dense.get_output_shape_at(0), (None, 16)) + self.assertEqual(dense.get_output_shape_at(1), (None, 16)) + + # Test invalid value for attribute retrieval. + with self.assertRaises(ValueError): + dense.get_input_at(2) + with self.assertRaises(AttributeError): + new_dense = core_layers.Dense(16) + _ = new_dense.input + with self.assertRaises(AttributeError): + new_dense = core_layers.Dense(16) + _ = new_dense.output + with self.assertRaises(AttributeError): + new_dense = core_layers.Dense(16) + _ = new_dense.output_shape + with self.assertRaises(AttributeError): + new_dense = core_layers.Dense(16) + _ = new_dense.input_shape + with self.assertRaises(AttributeError): + new_dense = core_layers.Dense(16) + a = network_layers.Input(shape=(3, 32)) + a = network_layers.Input(shape=(5, 32)) + a_2 = dense(a) + b_2 = dense(b) + _ = new_dense.input_shape + with self.assertRaises(AttributeError): + new_dense = core_layers.Dense(16) + a = network_layers.Input(shape=(3, 32)) + a = network_layers.Input(shape=(5, 32)) + a_2 = dense(a) + b_2 = dense(b) + _ = new_dense.output_shape + + def testTopologicalAttributesMultiOutputLayer(self): + + class PowersLayer(base_layers.Layer): + + def call(self, inputs): + return [inputs**2, inputs**3] + + x = network_layers.Input(shape=(32,)) + test_layer = PowersLayer() + p1, p2 = test_layer(x) # pylint: disable=not-callable + + self.assertEqual(test_layer.input, x) + self.assertEqual(test_layer.output, [p1, p2]) + self.assertEqual(test_layer.input_shape, (None, 32)) + self.assertEqual(test_layer.output_shape, [(None, 32), (None, 32)]) + + def testTopologicalAttributesMultiInputLayer(self): + + class AddLayer(base_layers.Layer): + + def call(self, inputs): + assert len(inputs) == 2 + return inputs[0] + inputs[1] + + a = network_layers.Input(shape=(32,)) + b = network_layers.Input(shape=(32,)) + test_layer = AddLayer() + y = test_layer([a, b]) # pylint: disable=not-callable + + self.assertEqual(test_layer.input, [a, b]) + self.assertEqual(test_layer.output, y) + self.assertEqual(test_layer.input_shape, [(None, 32), (None, 32)]) + self.assertEqual(test_layer.output_shape, (None, 32)) + + +class NetworkTest(test.TestCase): + + def testBasicNetwork(self): + # minimum viable network + x = network_layers.Input(shape=(32,)) + dense = core_layers.Dense(2) + y = dense(x) + network = network_layers.GraphNetwork(x, y, name='dense_network') + + # test basic attributes + self.assertEqual(network.name, 'dense_network') + self.assertEqual(len(network.layers), 2) # InputLayer + Dense + self.assertEqual(network.layers[1], dense) + self.assertEqual(network.weights, dense.weights) + self.assertEqual(network.trainable_weights, dense.trainable_weights) + self.assertEqual(network.non_trainable_weights, dense.non_trainable_weights) + + # test callability on Input + x_2 = network_layers.Input(shape=(32,)) + y_2 = network(x_2) + self.assertEqual(y_2.get_shape().as_list(), [None, 2]) + + # test callability on regular tensor + x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) + y_2 = network(x_2) + self.assertEqual(y_2.get_shape().as_list(), [None, 2]) + + # test network `trainable` attribute + network.trainable = False + self.assertEqual(network.weights, dense.weights) + self.assertEqual(network.trainable_weights, []) + self.assertEqual(network.non_trainable_weights, + dense.trainable_weights + dense.non_trainable_weights) + + def test_node_construction(self): + # test graph topology construction basics + a = network_layers.Input(shape=(32,), name='input_a') + b = network_layers.Input(shape=(32,), name='input_b') + + self.assertEqual(a.get_shape().as_list(), [None, 32]) + a_layer, a_node_index, a_tensor_index = a._keras_history + b_layer, _, _ = b._keras_history + self.assertEqual(len(a_layer._inbound_nodes), 1) + self.assertEqual(a_tensor_index, 0) + node = a_layer._inbound_nodes[a_node_index] + self.assertEqual(node.outbound_layer, a_layer) + + self.assertEqual(node.inbound_layers, []) + self.assertEqual(node.input_tensors, [a]) + self.assertEqual(node.input_shapes, [(None, 32)]) + self.assertEqual(node.output_tensors, [a]) + self.assertEqual(node.output_shapes, [(None, 32)]) + + dense = core_layers.Dense(16, name='dense_1') + dense(a) + dense(b) + + self.assertEqual(len(dense._inbound_nodes), 2) + self.assertEqual(len(dense._outbound_nodes), 0) + self.assertEqual(dense._inbound_nodes[0].inbound_layers, [a_layer]) + self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense) + self.assertEqual(dense._inbound_nodes[1].inbound_layers, [b_layer]) + self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense) + self.assertEqual(dense._inbound_nodes[0].input_tensors, [a]) + self.assertEqual(dense._inbound_nodes[1].input_tensors, [b]) + + # Test config + config_0 = dense._inbound_nodes[0].get_config() + self.assertEqual(config_0['outbound_layer'], dense.name) + + def testMultiInputNetwork(self): + a = network_layers.Input(shape=(32,), name='input_a') + b = network_layers.Input(shape=(32,), name='input_b') + + class AddLayer(base_layers.Layer): + + def call(self, inputs): + assert len(inputs) == 2 + return inputs[0] + inputs[1] + + c = AddLayer()([a, b]) # pylint: disable=not-callable + network = network_layers.GraphNetwork([a, b], c) + self.assertEqual(len(network.layers), 3) # 2 * InputLayer + AddLayer + + # Test callability. + a2 = network_layers.Input(shape=(32,)) + b2 = network_layers.Input(shape=(32,)) + c2 = network([a2, b2]) + self.assertEqual(c2.get_shape().as_list(), [None, 32]) + + def testMultiOutputNetwork(self): + x = network_layers.Input(shape=(32,)) + y1 = core_layers.Dense(2)(x) + y2 = core_layers.Dense(3)(x) + network = network_layers.GraphNetwork(x, [y1, y2]) + + self.assertEqual(len(network.layers), 3) # InputLayer + 2 * Dense + + # Test callability. + x2 = network_layers.Input(shape=(32,)) + outputs = network(x2) + + self.assertEqual(type(outputs), list) + self.assertEqual(len(outputs), 2) + self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) + self.assertEqual(outputs[1].get_shape().as_list(), [None, 3]) + + def testMultiInputMultiOutputNetworkSharedLayer(self): + a = network_layers.Input(shape=(32,), name='input_a') + b = network_layers.Input(shape=(32,), name='input_b') + + dense = core_layers.Dense(2) + + y1 = dense(a) + y2 = dense(b) + network = network_layers.GraphNetwork([a, b], [y1, y2]) + self.assertEqual(len(network.layers), 3) # 2 * InputLayer + Dense + + # Test callability. + a2 = network_layers.Input(shape=(32,)) + b2 = network_layers.Input(shape=(32,)) + outputs = network([a2, b2]) + + self.assertEqual(type(outputs), list) + self.assertEqual(len(outputs), 2) + self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) + self.assertEqual(outputs[1].get_shape().as_list(), [None, 2]) + + def testCrossDataFlows(self): + # Test the ability to have multi-output layers with outputs that get routed + # to separate layers + + class PowersLayer(base_layers.Layer): + + def call(self, inputs): + return [inputs**2, inputs**3] + + x = network_layers.Input(shape=(32,)) + p1, p2 = PowersLayer()(x) # pylint: disable=not-callable + y1 = core_layers.Dense(2)(p1) + y2 = core_layers.Dense(3)(p2) + network = network_layers.GraphNetwork(x, [y1, y2]) + + self.assertEqual(len(network.layers), 4) # InputLayer + 2 * Dense + PLayer + + # Test callability. + x2 = network_layers.Input(shape=(32,)) + outputs = network(x2) + + self.assertEqual(type(outputs), list) + self.assertEqual(len(outputs), 2) + self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) + self.assertEqual(outputs[1].get_shape().as_list(), [None, 3]) + + def testNetworkAttributes(self): + x = network_layers.Input(shape=(32,)) + z = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x**2))(x) + dense = core_layers.Dense(2, name='dense') + dense.add_update(1) + y = dense(z) + net = network_layers.GraphNetwork(x, y) + + # losses + self.assertEqual(len(net.losses), 1) + + # updates + self.assertEqual(len(net.updates), 1) + + # get_layer + self.assertEqual(net.get_layer('dense'), dense) + self.assertEqual(net.get_layer(index=2), dense) + with self.assertRaises(ValueError): + net.get_layer('dense_unknown') + with self.assertRaises(ValueError): + net.get_layer() + with self.assertRaises(ValueError): + net.get_layer(index=4) + + # input, output + self.assertEqual(net.input, x) + self.assertEqual(net.output, y) + + # input_shape, output_shape + self.assertEqual(net.input_shape, (None, 32)) + self.assertEqual(net.output_shape, (None, 2)) + + # get_*_at + self.assertEqual(net.get_input_at(0), x) + self.assertEqual(net.get_output_at(0), y) + + # _compute_output_shape + self.assertEqual(net._compute_output_shape((3, 32)).as_list(), [3, 2]) + + def testInvalidNetworks(self): + # redundant inputs + x = network_layers.Input(shape=(32,)) + y = core_layers.Dense(2)(x) + with self.assertRaises(ValueError): + network_layers.GraphNetwork([x, x], y) + + # inputs that don't come from Input + x = array_ops.placeholder(dtype='float32', shape=(None, 32)) + y = core_layers.Dense(2)(x) + with self.assertRaises(ValueError): + network_layers.GraphNetwork(x, y) + + # inputs that don't come from Input but have a layer history + x = network_layers.Input(shape=(32,)) + x = core_layers.Dense(32)(x) + y = core_layers.Dense(2)(x) + with self.assertRaises(ValueError): + network_layers.GraphNetwork(x, y) + + # outputs that don't come from layers + x = network_layers.Input(shape=(32,)) + y = core_layers.Dense(2)(x) + y = 2 * y + with self.assertRaises(ValueError): + network_layers.GraphNetwork(x, y) + + # disconnected graphs + x1 = network_layers.Input(shape=(32,)) + x2 = network_layers.Input(shape=(32,)) + y = core_layers.Dense(2)(x1) + with self.assertRaises(ValueError): + network_layers.GraphNetwork(x2, y) + + # redundant layer names + x = network_layers.Input(shape=(32,)) + z = core_layers.Dense(2, name='dense')(x) + y = core_layers.Dense(2, name='dense')(z) + with self.assertRaises(ValueError): + network_layers.GraphNetwork(x, y) + + def testInputTensorWrapping(self): + x = array_ops.placeholder(dtype='float32', shape=(None, 32)) + x = network_layers.Input(tensor=x) + y = core_layers.Dense(2)(x) + network_layers.GraphNetwork(x, y) + + def testExplicitBatchSize(self): + x = network_layers.Input(shape=(32,), batch_size=3) + y = core_layers.Dense(2)(x) + self.assertEqual(y.get_shape().as_list(), [3, 2]) + + def testNetworkRecursion(self): + # test the ability of networks to be used as layers inside networks. + a = network_layers.Input(shape=(32,)) + b = core_layers.Dense(2)(a) + net = network_layers.GraphNetwork(a, b) + + c = network_layers.Input(shape=(32,)) + d = net(c) + + recursive_net = network_layers.GraphNetwork(c, d) + self.assertEqual(len(recursive_net.layers), 2) + self.assertEqual(recursive_net.layers[1], net) + self.assertEqual(len(recursive_net.weights), 2) + + # test callability + x = array_ops.placeholder(dtype='float32', shape=(None, 32)) + y = recursive_net(x) + self.assertEqual(y.get_shape().as_list(), [None, 2]) + + def testSparseInput(self): + + class SparseSoftmax(base_layers.Layer): + + def call(self, inputs): + return sparse_ops.sparse_softmax(inputs) + + x = network_layers.Input(shape=(32,), sparse=True) + y = SparseSoftmax()(x) # pylint: disable=not-callable + network = network_layers.GraphNetwork(x, y) + + self.assertEqual(len(network.layers), 2) + self.assertEqual(network.layers[0].sparse, True) + + @test_util.run_in_graph_and_eager_modes() + def testMaskingSingleInput(self): + + class MaskedLayer(base_layers.Layer): + + def call(self, inputs, mask=None): + if mask is not None: + return inputs * mask + return inputs + + def compute_mask(self, inputs, mask=None): + return array_ops.ones_like(inputs) + + if context.in_graph_mode(): + x = network_layers.Input(shape=(32,)) + y = MaskedLayer()(x) # pylint: disable=not-callable + network = network_layers.GraphNetwork(x, y) + + # test callability on Input + x_2 = network_layers.Input(shape=(32,)) + y_2 = network(x_2) + self.assertEqual(y_2.get_shape().as_list(), [None, 32]) + + # test callability on regular tensor + x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) + y_2 = network(x_2) + self.assertEqual(y_2.get_shape().as_list(), [None, 32]) + else: + a = constant_op.constant([2] * 32) + mask = constant_op.constant([0, 1] * 16) + a._keras_mask = mask + b = MaskedLayer().apply(a) + self.assertTrue(hasattr(b, '_keras_mask')) + self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)), + self.evaluate(getattr(b, '_keras_mask'))) + self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b)) + + +class DeferredModeTest(test.TestCase): + + def testDeferredTensorAttributes(self): + x = base_layers._DeferredTensor(shape=(None, 2), dtype='float32', name='x') + self.assertEqual(str(x), + 'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)') + self.assertEqual(repr(x), + '<_DeferredTensor \'x\' shape=(?, 2) dtype=float32>') + + @test_util.run_in_graph_and_eager_modes() + def testSimpleNetworkBuilding(self): + inputs = network_layers.Input(shape=(32,)) + if context.in_eager_mode(): + self.assertIsInstance(inputs, base_layers._DeferredTensor) + self.assertEqual(inputs.dtype.name, 'float32') + self.assertEqual(inputs.shape.as_list(), [None, 32]) + + x = core_layers.Dense(2)(inputs) + if context.in_eager_mode(): + self.assertIsInstance(x, base_layers._DeferredTensor) + self.assertEqual(x.dtype.name, 'float32') + self.assertEqual(x.shape.as_list(), [None, 2]) + + outputs = core_layers.Dense(4)(x) + network = network_layers.GraphNetwork(inputs, outputs) + self.assertIsInstance(network, network_layers.GraphNetwork) + + if context.in_eager_mode(): + # It should be possible to call such a network on EagerTensors. + inputs = constant_op.constant( + np.random.random((10, 32)).astype('float32')) + outputs = network(inputs) + self.assertEqual(outputs.shape.as_list(), [10, 4]) + + @test_util.run_in_graph_and_eager_modes() + def testMultiIONetworkbuilding(self): + input_a = network_layers.Input(shape=(32,)) + input_b = network_layers.Input(shape=(16,)) + a = core_layers.Dense(16)(input_a) + + class AddLayer(base_layers.Layer): + + def call(self, inputs): + return inputs[0] + inputs[1] + + def _compute_output_shape(self, input_shape): + return input_shape[0] + + c = AddLayer()([a, input_b]) # pylint: disable=not-callable + c = core_layers.Dense(2)(c) + + network = network_layers.GraphNetwork([input_a, input_b], [a, c]) + if context.in_eager_mode(): + a_val = constant_op.constant( + np.random.random((10, 32)).astype('float32')) + b_val = constant_op.constant( + np.random.random((10, 16)).astype('float32')) + outputs = network([a_val, b_val]) + self.assertEqual(len(outputs), 2) + self.assertEqual(outputs[0].shape.as_list(), [10, 16]) + self.assertEqual(outputs[1].shape.as_list(), [10, 2]) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py index 7c71d3c952c..766a6800d44 100644 --- a/tensorflow/python/layers/utils.py +++ b/tensorflow/python/layers/utils.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import variables from tensorflow.python.ops import control_flow_ops from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util +from tensorflow.python.util import nest def convert_data_format(data_format, ndim): @@ -232,3 +233,19 @@ def constant_value(pred): else: raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.') return pred_value + + +def object_list_uid(object_list): + """Creates a single string from object ids.""" + object_list = nest.flatten(object_list) + return ', '.join([str(abs(id(x))) for x in object_list]) + + +def static_shape(x): + """Get the static shape of a Tensor, or None if it is unavailable.""" + if x is None: + return None + try: + return tuple(x.get_shape().as_list()) + except ValueError: + return None diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index a62847614c6..b30125761fc 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -176,7 +176,8 @@ string PyExcFetch() { } // Calls the registered py function through the trampoline. -Status DoCallPyFunc(PyCall* call) { +Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { + *out_log_on_error = true; PyObject* trampoline = GetPyTrampoline(); if (trampoline == nullptr) { return errors::InvalidArgument( @@ -196,6 +197,7 @@ Status DoCallPyFunc(PyCall* call) { PyErr_ExceptionMatches(PyExc_TypeError)) { return errors::InvalidArgument(PyExcFetch()); } else if (PyErr_ExceptionMatches(PyExc_StopIteration)) { + *out_log_on_error = false; return errors::OutOfRange(PyExcFetch()); } else if (PyErr_ExceptionMatches(PyExc_MemoryError)) { return errors::ResourceExhausted(PyExcFetch()); @@ -426,11 +428,19 @@ class PyFuncOp : public OpKernel { PyGILState_STATE py_threadstate; py_threadstate = PyGILState_Ensure(); - Status s = DoCallPyFunc(&call); + bool log_on_error; + Status s = DoCallPyFunc(&call, &log_on_error); PyGILState_Release(py_threadstate); // Ensures that GIL is released even when !s.ok(). - OP_REQUIRES_OK(ctx, s); + if (!s.ok()) { + if (log_on_error) { + ctx->CtxFailureWithWarning(s); + } else { + ctx->CtxFailure(s); + } + return; + } OP_REQUIRES(ctx, static_cast(call.out.size()) == ctx->num_outputs(), errors::InvalidArgument(token_, " returns ", call.out.size(), diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 1211b2e9230..dacc2947fe3 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -573,7 +573,9 @@ class HessianVectorProductTest(test_util.TensorFlowTestCase): self.assertAllClose(hess_v_value, hess_v_actual) -@test_util.with_c_api +# TODO(skyewm): reenable C API once +# ControlFlowContext._RemoveExternalControlEdges works with C API enabled +# @test_util.with_c_api class HessianTest(test_util.TensorFlowTestCase): def testHessian1D(self): diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py index 3d0ea3e11be..2c11f90e6d9 100644 --- a/tensorflow/python/ops/linalg/linear_operator_test_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py @@ -66,11 +66,23 @@ class LinearOperatorDerivedClassTest(test.TestCase): rtol = self._rtol[dtype] self.assertAllClose(x, y, atol=atol, rtol=rtol) + @property + def _adjoint_options(self): + return [False, True] + + @property + def _adjoint_arg_options(self): + return [False, True] + @property def _dtypes_to_test(self): # TODO(langmore) Test tf.float16 once tf.matrix_solve works in 16bit. return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] + @property + def _use_placeholder_options(self): + return [False, True] + @abc.abstractproperty def _shapes_to_test(self): """Returns list of tuples, each is one shape that will be tested.""" @@ -151,7 +163,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_to_dense(self): self._skip_if_tests_to_skip_contains("to_dense") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: @@ -166,7 +178,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_det(self): self._skip_if_tests_to_skip_contains("det") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: @@ -183,7 +195,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_log_abs_det(self): self._skip_if_tests_to_skip_contains("log_abs_det") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: @@ -200,11 +212,11 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_matmul(self): self._skip_if_tests_to_skip_contains("matmul") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: - for adjoint in False, True: - for adjoint_arg in False, True: + for adjoint in self._adjoint_options: + for adjoint_arg in self._adjoint_arg_options: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( @@ -228,11 +240,11 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_solve(self): self._skip_if_tests_to_skip_contains("solve") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: - for adjoint in False, True: - for adjoint_arg in False, True: + for adjoint in self._adjoint_options: + for adjoint_arg in self._adjoint_arg_options: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( @@ -257,7 +269,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_trace(self): self._skip_if_tests_to_skip_contains("trace") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: @@ -274,7 +286,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_add_to_tensor(self): self._skip_if_tests_to_skip_contains("add_to_tensor") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: @@ -293,7 +305,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_diag_part(self): self._skip_if_tests_to_skip_contains("diag_part") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index fa58ffc37e2..156e415735f 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -561,9 +561,9 @@ class TextFileStringTableInitializer(TextFileInitializer): The path must be accessible from wherever the graph is initialized (eg. trainer or eval workers). The filename may be a scalar `Tensor`. key_column_index: The column index from the text file to get the keys - from. The default is 0 that represents the whole line content. + from. The default is to use the line number, starting from zero. value_column_index: The column index from the text file to get the - values from. The default is to use the line number, starting from zero. + values from. The default is to use the whole line content. vocab_size: The number of elements in the file, if known. delimiter: The delimiter to separate fields in a line. name: Optional name for the op. @@ -613,9 +613,9 @@ class TextFileIdTableInitializer(TextFileInitializer): The path must be accessible from wherever the graph is initialized (eg. trainer or eval workers). The filename may be a scalar `Tensor`. key_column_index: The column index from the text file to get the `key` + values from. The default is to use the whole line content. + value_column_index: The column index from the text file to get the `value` values from. The default is to use the line number, starting from zero. - value_column_index: The column index from the text file ro get the `value` - values from. The default is 0 that represents the whole line content. vocab_size: The number of elements in the file, if known. delimiter: The delimiter to separate fields in a line. name: Optional name for the op. @@ -864,7 +864,10 @@ def index_table_from_file(vocabulary_file=None, default_value=-1, hasher_spec=FastHashSpec, key_dtype=dtypes.string, - name=None): + name=None, + key_column_index=TextFileIndex.WHOLE_LINE, + value_column_index=TextFileIndex.LINE_NUMBER, + delimiter="\t"): """Returns a lookup table that converts a string tensor into int64 IDs. This operation constructs a lookup table to convert tensor of strings into @@ -881,6 +884,16 @@ def index_table_from_file(vocabulary_file=None, The underlying table must be initialized by calling `tf.tables_initializer.run()` or `table.init.run()` once. + To specify multi-column vocabulary files, use key_column_index and + value_column_index and delimiter. + + - TextFileIndex.LINE_NUMBER means use the line number starting from zero, + expects data type int64. + - TextFileIndex.WHOLE_LINE means use the whole line content, expects data + type string. + - A value >=0 means use the index (starting at zero) of the split line based + on `delimiter`. + Sample Usages: If we have a vocabulary file "test.txt" with the following content: @@ -912,6 +925,11 @@ def index_table_from_file(vocabulary_file=None, assignation of out-of-vocabulary buckets. key_dtype: The `key` data type. name: A name for this op (optional). + key_column_index: The column index from the text file to get the `key` + values from. The default is to use the whole line content. + value_column_index: The column index from the text file to get the `value` + values from. The default is to use the line number, starting from zero. + delimiter: The delimiter to separate fields in a line. Returns: The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`. @@ -944,19 +962,22 @@ def index_table_from_file(vocabulary_file=None, # Keep the shared_name: # ____ shared_name = "hash_table_%s_%d_%s_%s" % (vocabulary_file, vocab_size, - TextFileIndex.WHOLE_LINE, - TextFileIndex.LINE_NUMBER) + key_column_index, + value_column_index) else: # Keep the shared_name # ___ shared_name = "hash_table_%s_%s_%s" % (vocabulary_file, - TextFileIndex.WHOLE_LINE, - TextFileIndex.LINE_NUMBER) + key_column_index, + value_column_index) init = TextFileIdTableInitializer( vocabulary_file, vocab_size=vocab_size, key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype, - name="table_init") + name="table_init", + key_column_index=key_column_index, + value_column_index=value_column_index, + delimiter=delimiter) table = HashTable( init, default_value, shared_name=shared_name, name=hash_table_scope) @@ -1074,7 +1095,10 @@ def index_table_from_tensor(vocabulary_list, def index_to_string_table_from_file(vocabulary_file, vocab_size=None, default_value="UNK", - name=None): + name=None, + key_column_index=TextFileIndex.LINE_NUMBER, + value_column_index=TextFileIndex.WHOLE_LINE, + delimiter="\t"): """Returns a lookup table that maps a `Tensor` of indices into strings. This operation constructs a lookup table to map int64 indices into string @@ -1088,6 +1112,16 @@ def index_to_string_table_from_file(vocabulary_file, The underlying table must be initialized by calling `tf.tables_initializer.run()` or `table.init.run()` once. + To specify multi-column vocabulary files, use key_column_index and + value_column_index and delimiter. + + - TextFileIndex.LINE_NUMBER means use the line number starting from zero, + expects data type int64. + - TextFileIndex.WHOLE_LINE means use the whole line content, expects data + type string. + - A value >=0 means use the index (starting at zero) of the split line based + on `delimiter`. + Sample Usages: If we have a vocabulary file "test.txt" with the following content: @@ -1114,6 +1148,11 @@ def index_to_string_table_from_file(vocabulary_file, vocab_size: Number of the elements in the vocabulary, if known. default_value: The value to use for out-of-vocabulary indices. name: A name for this op (optional). + key_column_index: The column index from the text file to get the `key` + values from. The default is to use the line number, starting from zero. + value_column_index: The column index from the text file to get the `value` + values from. The default is to use the whole line content. + delimiter: The delimiter to separate fields in a line. Returns: The lookup table to map a string values associated to a given index `int64` @@ -1134,15 +1173,19 @@ def index_to_string_table_from_file(vocabulary_file, # Keep a shared_name # ____ shared_name = "hash_table_%s_%d_%s_%s" % (vocabulary_file, vocab_size, - TextFileIndex.LINE_NUMBER, - TextFileIndex.WHOLE_LINE) + key_column_index, + value_column_index) else: # Keep a shared_name ___ - shared_name = "hash_table_%s_%s_%s" % (vocabulary_file, - TextFileIndex.LINE_NUMBER, - TextFileIndex.WHOLE_LINE) + shared_name = "hash_table_%s_%s_%s" % (vocabulary_file, key_column_index, + value_column_index) init = TextFileStringTableInitializer( - vocabulary_file, vocab_size=vocab_size, name="table_init") + vocabulary_file, + vocab_size=vocab_size, + name="table_init", + key_column_index=key_column_index, + value_column_index=value_column_index, + delimiter=delimiter) # TODO(yleon): Use a more effienct structure. return HashTable(init, default_value, shared_name=shared_name, name=scope) diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 870c4f40623..d30f6b92ad4 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -1511,6 +1511,56 @@ def false_positives_at_thresholds(labels, predictions, thresholds, weights=None, return values['fp'], update_ops['fp'] +def true_negatives(labels, predictions, weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Sum the weights of true_negatives. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will + be cast to `bool`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that the metric + value variable should be added to. + updates_collections: An optional list of collections that the metric update + ops should be added to. + name: An optional variable_scope name. + + Returns: + value_tensor: A `Tensor` representing the current value of the metric. + update_op: An operation that accumulates the error from a batch of data. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + RuntimeError: If eager execution is enabled. + """ + if context.in_eager_mode(): + raise RuntimeError('tf.metrics.true_negatives is not ' + 'supported when eager execution is enabled.') + + with variable_scope.variable_scope( + name, 'true_negatives', (predictions, labels, weights)): + + predictions, labels, weights = _remove_squeezable_dimensions( + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) + is_true_negative = math_ops.logical_and(math_ops.equal(labels, False), + math_ops.equal(predictions, False)) + return _count_condition(is_true_negative, weights, metrics_collections, + updates_collections) + + def true_negatives_at_thresholds(labels, predictions, thresholds, weights=None, metrics_collections=None, updates_collections=None, diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 24ef70c6f4d..98578b799a8 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -21,6 +21,7 @@ from __future__ import print_function import functools import traceback +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging @@ -138,6 +139,10 @@ def make_template(name_, func_, create_scope_now_=False, unique_name_=None, """ if kwargs: func_ = functools.partial(func_, **kwargs) + if context.in_eager_mode(): + return EagerTemplate( + name_, func_, create_scope_now=create_scope_now_, + unique_name=unique_name_, custom_getter=custom_getter_) return Template( name_, func_, create_scope_now=create_scope_now_, unique_name=unique_name_, custom_getter=custom_getter_) @@ -336,3 +341,184 @@ class Template(object): def var_scope(self): """Returns the variable scope object created by this Template.""" return self._variable_scope + + +class EagerTemplate(Template): + """Wrap a function to aid in variable sharing in Eager mode. + + Templates are functions that create variables the first time they are called + and reuse them thereafter. See `make_template` for full documentation. + + Note: By default, the full variable scope is captured at the time of first + call. If `create_scope_now` is passed as True to the constructor, the full + scope will be captured there, but no variables will be created until the first + call. + """ + + def __init__(self, name, func, create_scope_now=False, unique_name=None, + custom_getter=None): + """Creates a template for the given function. + + Args: + name: A name for the scope created by this template. The + name will be made unique by appending `_N` to the it (see how + `tf.variable_scope` treats the `default_name` for details). + func: The function to apply each time. + create_scope_now: Whether to create the scope at Template construction + time, rather than first call. Defaults to false. Creating the scope at + construction time may be more convenient if the template is passed + through much lower level code, and you want to be sure of the scope + name without knowing exactly where it will be first called. If set to + True, the scope will be created in the constructor, and all subsequent + times in __call__, leading to a trailing numeral being added to the + names of all created Tensors. If set to False, the scope will be created + at the first call location. + unique_name: When used, it overrides name_ and is not made unique. If a + template of the same scope/unique_name already exists and reuse is + false, an error is raised. Defaults to None. + custom_getter: optional custom getter to pass to variable_scope() + + Raises: + RuntimeError: if eager mode is not enabled. + ValueError: if the name is None or unique_name is provided. + """ + if not context.in_eager_mode(): + raise RuntimeError( + "{} objects can only be used when eager execution is enabled, use " + "tf.Template for graph construction". + format(type(self))) + if unique_name: + raise ValueError("unique_name cannot be used in eager mode.") + super(EagerTemplate, self).__init__(name, func, create_scope_now, + unique_name, custom_getter) + # Create an eager variable store only if the current variable store cannot + # store eager variables. This should allow for correct nesting. + default_vstore = variable_scope._get_default_variable_store() # pylint: disable=protected-access + if default_vstore._store_eager_variables: # pylint: disable=protected-access + raise ValueError("Nested EagerTemaplates are not currently supported.") + else: + self._eager_variable_store = variable_scope.EagerVariableStore() + + def _call_func(self, args, kwargs, check_for_new_variables): + try: + vars_at_start = self._eager_variable_store.variables() + trainable_at_start = self._eager_variable_store.trainable_variables() + + result = self._func(*args, **kwargs) + if check_for_new_variables: + trainable_variables = self._eager_variable_store.trainable_variables() + # If a variable that we intend to train is created as a side effect + # of creating a template, then that is almost certainly an error. + if len(trainable_at_start) != len(trainable_variables): + raise ValueError("Trainable variable created when calling a template " + "after the first time, perhaps you used tf.Variable " + "when you meant tf.get_variable: %s" % + list(set(trainable_variables) - + set(trainable_at_start))) + + # Non-trainable tracking variables are a legitimate reason why a new + # variable would be created, but it is a relatively advanced use-case, + # so log it. + variables = self._eager_variable_store.variables() + if len(vars_at_start) != len(variables): + logging.info("New variables created when calling a template after " + "the first time, perhaps you used tf.Variable when you " + "meant tf.get_variable: %s", + list(set(variables) - set(vars_at_start))) + return result + except Exception as exc: + # Reraise the exception, but append the original definition to the + # trace. + args = exc.args + if not args: + arg0 = "" + else: + arg0 = args[0] + trace = "".join(_skip_common_stack_elements(self._stacktrace, + traceback.format_stack())) + arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace) + new_args = [arg0] + new_args.extend(args[1:]) + exc.args = tuple(new_args) + raise + + def __call__(self, *args, **kwargs): + if self._variable_scope: + if self._variables_created: + # This is not the first visit to __call__, so variables have already + # been created, and we want to reuse them. + with variable_scope.variable_scope(self._variable_scope, + reuse=variable_scope.AUTO_REUSE): + with self._eager_variable_store.as_default(): + return self._call_func(args, kwargs, check_for_new_variables=True) + else: + # This is the first visit to __call__, but the scope has already been + # created in the constructor. Set _variables_created after the inner + # function is successfully called so that subsequent calls take the if + # branch above. + with variable_scope.variable_scope(self._variable_scope, + reuse=variable_scope.AUTO_REUSE): + with self._eager_variable_store.as_default(): + result = self._call_func(args, kwargs, + check_for_new_variables=False) + self._variables_created = True + return result + else: + # The scope was not created at construction time, so create it here. + # Subsequent calls should reuse variables. + with variable_scope.variable_scope( + self._unique_name, self._name, + custom_getter=self._custom_getter) as vs: + self._variable_scope = vs + with self._eager_variable_store.as_default(): + result = self._call_func(args, kwargs, + check_for_new_variables=False) + self._variables_created = True + return result + + @property + def name(self): + """Returns the name given to this Template.""" + return self._name + + @property + def func(self): + """Returns the func given to this Template.""" + return self._func + + @property + def variable_scope(self): + """Returns the variable scope object created by this Template.""" + return self._variable_scope + + @property + def variable_scope_name(self): + """Returns the variable scope name created by this Template.""" + if self._variable_scope: + name = self._variable_scope.name + # To prevent partial matches on the scope_name, we add '/' at the end. + return name if name[-1] == "/" else name + "/" + + @property + def variables(self): + """Returns the list of trainable variables created by the Template.""" + # Currently there is no local variable in Eager mode. + return self._eager_variable_store.variables() + + @property + def trainable_variables(self): + """Returns the list of trainable variables created by the Template.""" + # Currently there is no local variable in Eager mode. + return self._eager_variable_store.trainable_variables() + + @property + def global_variables(self): + """Returns the list of global variables created by the Template.""" + # Currently there is no local variable in Eager mode. + return self.variables + + @property + def local_variables(self): + """Returns the list of global variables created by the Template.""" + # Currently there is no local variable in Eager mode. + return [] diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index ea5354c1d6a..605654d9be7 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -36,6 +36,9 @@ from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import tf_should_use +# TODO(ebrevdo): Set to True in Dec. 4, 2017. +_ENABLE_IDENTICAL_ELEMENT_SHAPES = False + # _GraphTensorArray accesses many of the hidden generated ops, but is in # fact built to wrap these methods. @@ -146,6 +149,10 @@ class _GraphTensorArray(object): # write into the TensorArray from a Tensor with a set device # will retroactively set the device value of this op. def create(): + """Create the TensorArray op.""" + ta_kwargs = {} + if _ENABLE_IDENTICAL_ELEMENT_SHAPES: + ta_kwargs["identical_element_shapes"] = infer_shape return gen_data_flow_ops._tensor_array_v3( dtype=dtype, size=size, @@ -153,7 +160,8 @@ class _GraphTensorArray(object): dynamic_size=dynamic_size, clear_after_read=clear_after_read, tensor_array_name=tensor_array_name, - name=scope) + name=scope, + **ta_kwargs) if colocate_with_first_write_call: with ops.device(None), ops.colocate_with(None, ignore_existing=True): self._handle, self._flow = create() diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 9a0ff755941..91dea12da23 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1225,11 +1225,12 @@ class EagerVariableStore(object): return with_variable_store(self._store) def variables(self): - return self._store._vars.values() # pylint: disable=protected-access + return sorted(self._store._vars.values(), key=lambda x: x.name) # pylint: disable=protected-access def trainable_variables(self): # pylint: disable=protected-access - return [x for x in self._store._vars.values() if x._trainable] + return sorted([x for x in self._store._vars.values() if x._trainable], + key=lambda x: x.name) # pylint: enable=protected-access @@ -1827,7 +1828,13 @@ class variable_scope(object): # pylint: disable=invalid-name self._current_name_scope = None def __enter__(self): - if self._in_graph_mode: + # If the default graph is building a function, then we should not replace it + # with the cached graph. + if ops.get_default_graph().building_function: + self._building_function = True + else: + self._building_function = False + if self._in_graph_mode and not self._building_function: self._graph_context_manager = self._graph.as_default() self._graph_context_manager.__enter__() if self._cached_pure_variable_scope is not None: @@ -1906,7 +1913,7 @@ class variable_scope(object): # pylint: disable=invalid-name type_arg, value_arg, traceback_arg) if self._current_name_scope: self._current_name_scope.__exit__(type_arg, value_arg, traceback_arg) - if self._in_graph_mode: + if self._in_graph_mode and not self._building_function: self._graph_context_manager.__exit__(type_arg, value_arg, traceback_arg) diff --git a/tensorflow/python/profiler/model_analyzer.py b/tensorflow/python/profiler/model_analyzer.py index 040a4891637..46a921c0a13 100644 --- a/tensorflow/python/profiler/model_analyzer.py +++ b/tensorflow/python/profiler/model_analyzer.py @@ -20,6 +20,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys + import six from google.protobuf import message @@ -206,8 +208,8 @@ class Profiler(object): try: tfprof_node.ParseFromString( print_mdl.Profile('code'.encode('utf-8'), opts.SerializeToString())) - except message.DecodeError as _: - pass + except message.DecodeError as e: + sys.stderr.write('Cannot parse returned proto: %s.\n' % e) return tfprof_node def profile_operations(self, options): @@ -223,8 +225,8 @@ class Profiler(object): try: tfprof_node.ParseFromString( print_mdl.Profile('op'.encode('utf-8'), opts.SerializeToString())) - except message.DecodeError as _: - pass + except message.DecodeError as e: + sys.stderr.write('Cannot parse returned proto: %s.\n' % e) return tfprof_node def profile_name_scope(self, options): @@ -240,8 +242,8 @@ class Profiler(object): try: tfprof_node.ParseFromString( print_mdl.Profile('scope'.encode('utf-8'), opts.SerializeToString())) - except message.DecodeError as _: - pass + except message.DecodeError as e: + sys.stderr.write('Cannot parse returned proto: %s.\n' % e) return tfprof_node def profile_graph(self, options): @@ -257,8 +259,8 @@ class Profiler(object): try: tfprof_node.ParseFromString( print_mdl.Profile('graph'.encode('utf-8'), opts.SerializeToString())) - except message.DecodeError as _: - pass + except message.DecodeError as e: + sys.stderr.write('Cannot parse returned proto: %s.\n' % e) return tfprof_node def advise(self, options): @@ -331,9 +333,8 @@ def profile(graph, opts.SerializeToString()) try: tfprof_node.ParseFromString(ret) - except message.DecodeError as _: - pass - # sys.stderr.write('Cannot parse returned proto: %s.\n' % e) + except message.DecodeError as e: + sys.stderr.write('Cannot parse returned proto: %s.\n' % e) elif cmd == 'graph' or cmd == 'scope': tfprof_node = tfprof_output_pb2.GraphNodeProto() @@ -345,9 +346,8 @@ def profile(graph, opts.SerializeToString()) try: tfprof_node.ParseFromString(ret) - except message.DecodeError as _: - pass - # sys.stderr.write('Cannot parse returned proto: %s.\n' % e) + except message.DecodeError as e: + sys.stderr.write('Cannot parse returned proto: %s.\n' % e) else: raise errors.InvalidArgumentError( None, None, 'unknown cmd: %s\n' % cmd) diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 5ca0e572869..82b154164e8 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -24,13 +24,16 @@ limitations under the License. %rename("%s") TFE_Py_RegisterExceptionClass; %rename("%s") TFE_Py_Execute; %rename("%s") TFE_Py_UID; -%rename("%s") TFE_Py_NewTape; -%rename("%s") TFE_Py_TapeShouldRecord; -%rename("%s") TFE_Py_TapeWatch; -%rename("%s") TFE_Py_TapeDeleteTrace; -%rename("%s") TFE_Py_TapeRecordOperation; +%rename("%s") TFE_Py_TapeStackPushNew; +%rename("%s") TFE_Py_TapeStackPush; +%rename("%s") TFE_Py_TapeStackPop; +%rename("%s") TFE_Py_TapeStackIsEmpty; +%rename("%s") TFE_Py_TapeStackShouldRecord; +%rename("%s") TFE_Py_TapeStackWatch; +%rename("%s") TFE_Py_TapeStackDeleteTrace; +%rename("%s") TFE_Py_TapeStackRecordOperation; +%rename("%s") TFE_Py_TapeStackWatchVariable; %rename("%s") TFE_Py_TapeGradient; -%rename("%s") TFE_Py_TapeWatchVariable; %rename("%s") TFE_Py_TapeWatchedVariables; %rename("%s") TFE_NewContextOptions; %rename("%s") TFE_ContextOptionsSetConfig; diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 9f5e8ec9389..b31d02eb8d7 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -381,7 +381,7 @@ class Optimizer(object): loss: A Tensor containing the value to minimize. var_list: Optional list or tuple of `tf.Variable` to update to minimize `loss`. Defaults to the list of variables collected in the graph - under the key `GraphKey.TRAINABLE_VARIABLES`. + under the key `GraphKeys.TRAINABLE_VARIABLES`. gate_gradients: How to gate the computation of gradients. Can be `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. aggregation_method: Specifies the method used to combine gradient terms. diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt index b6f9eea2dea..07b8d900da5 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.Model" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -152,7 +152,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " } member_method { name: "evaluate_generator" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt index 5076434dbb5..546bac44e4c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -153,7 +153,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" @@ -173,11 +173,11 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'32\', \'10\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " } member_method { name: "fit_generator" - argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], " + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " } member_method { name: "from_config" @@ -241,7 +241,7 @@ tf_class { } member_method { name: "predict_classes" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'1\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " } member_method { name: "predict_generator" @@ -253,7 +253,7 @@ tf_class { } member_method { name: "predict_proba" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'1\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " } member_method { name: "reset_states" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt new file mode 100644 index 00000000000..791cfda2334 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt @@ -0,0 +1,3 @@ +path: "tensorflow.keras.datasets.fashion_mnist" +tf_module { +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt index d4aa436f328..36e3aafbe4d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "cifar100" mtype: "" } + member { + name: "fashion_mnist" + mtype: "" + } member { name: "imdb" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index a0906e62cf5..8c2b110c6d3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -191,7 +191,7 @@ tf_class { argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { - name: "reccurent_conv" + name: "recurrent_conv" argspec: "args=[\'self\', \'x\', \'w\'], varargs=None, keywords=None, defaults=None" } member_method { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt index b2df5fba8fd..49841237cef 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.InputLayer" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt index 7867e3c1fd3..f289664ba27 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt @@ -93,7 +93,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt index 0fb6e84f8de..d7887286125 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt @@ -93,7 +93,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt index 34c9efb3ca0..dedef65ff93 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt @@ -9,10 +9,6 @@ tf_class { name: "activity_regularizer" mtype: "" } - member { - name: "constraints" - mtype: "" - } member { name: "dtype" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt index 9cee68874a9..313b3a9e155 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt @@ -8,10 +8,6 @@ tf_class { name: "activity_regularizer" mtype: "" } - member { - name: "constraints" - mtype: "" - } member { name: "dtype" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt index af9a44086fd..4e522813a5a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.models.Model" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -152,7 +152,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " } member_method { name: "evaluate_generator" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt index 5034fdff2a6..ddbb358c84c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -153,7 +153,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" @@ -173,11 +173,11 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'32\', \'10\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " } member_method { name: "fit_generator" - argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], " + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " } member_method { name: "from_config" @@ -241,7 +241,7 @@ tf_class { } member_method { name: "predict_classes" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'1\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " } member_method { name: "predict_generator" @@ -253,7 +253,7 @@ tf_class { } member_method { name: "predict_proba" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'1\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " } member_method { name: "reset_states" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt index 8ad1f32551d..66cd37bb3a3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.keras.preprocessing.image.DirectoryIterator" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" @@ -11,6 +12,10 @@ tf_class { name: "next" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "reset" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt index d30462a8eb6..69488d63bf1 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt @@ -1,11 +1,16 @@ path: "tensorflow.keras.preprocessing.image.Iterator" tf_class { is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" argspec: "args=[\'self\', \'n\', \'batch_size\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "reset" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt index 841f1c5585e..4ef6e6e99e3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.keras.preprocessing.image.NumpyArrayIterator" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" @@ -11,6 +12,10 @@ tf_class { name: "next" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "reset" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt index 56526870335..d28fef69651 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt @@ -34,7 +34,7 @@ tf_module { } member_method { name: "load_img" - argspec: "args=[\'path\', \'grayscale\', \'target_size\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'path\', \'grayscale\', \'target_size\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'nearest\'], " } member_method { name: "random_channel_shift" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt index bf27a97cf25..1c5868e711b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'wait_time\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'0.05\', \'None\'], " + argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'wait_time\', \'seed\'], varargs=None, keywords=None, defaults=[\'False\', \'0.05\', \'None\'], " } member_method { name: "get" diff --git a/tensorflow/tools/api/golden/tensorflow.metrics.pbtxt b/tensorflow/tools/api/golden/tensorflow.metrics.pbtxt index 85088834b79..e9b996c9f53 100644 --- a/tensorflow/tools/api/golden/tensorflow.metrics.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.metrics.pbtxt @@ -116,6 +116,10 @@ tf_module { name: "specificity_at_sensitivity" argspec: "args=[\'labels\', \'predictions\', \'sensitivity\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'None\'], " } + member_method { + name: "true_negatives" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } member_method { name: "true_negatives_at_thresholds" argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh index 55f32f40f8d..c27f4953e3d 100755 --- a/tensorflow/tools/ci_build/ci_parameterized_build.sh +++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh @@ -546,8 +546,9 @@ echo "" TMP_DIR="" DOCKERFILE_FLAG="" -if [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.5" ]]; then - # Modify Dockerfile for Python3.5 build +if [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.5" ] || + ["${TF_BUILD_PYTHON_VERSION}" == "python3.6" ]]; then + # Modify Dockerfile for Python3.5 | Python3.6 build TMP_DIR=$(mktemp -d) echo "Docker build will occur in temporary directory: ${TMP_DIR}" @@ -563,10 +564,10 @@ if [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.5" ]]; then # Replace a line in the Dockerfile if sed -i \ - 's/RUN \/install\/install_pip_packages.sh/RUN \/install\/install_python3.5_pip_packages.sh/g' \ + "s/RUN \/install\/install_pip_packages.sh/RUN \/install\/install_${TF_BUILD_PYTHON_VERSION}_pip_packages.sh/g" \ "${DOCKERFILE}" then - echo "Copied and modified Dockerfile for Python 3.5 build: ${DOCKERFILE}" + echo "Copied and modified Dockerfile for ${TF_BUILD_PYTHON_VERSION} build: ${DOCKERFILE}" else die "ERROR: Faild to copy and modify Dockerfile: ${DOCKERFILE}" fi diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index 81bce95d543..479242aa437 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -18,33 +18,12 @@ # TODO(cais): Remove this file once we upgrade to ubuntu:16.04 docker images for # Python 3.5 builds. +# LINT.IfChange + # fkrull/deadsnakes is for Python3.5 add-apt-repository -y ppa:fkrull/deadsnakes apt-get update -set +e -# Upgrade swig to 3.0.8 -SWIG_VERSION="3.0.8" -swig_ver_flat=$(echo $SWIG_VERSION | sed 's/\.//g' | sed 's/^0*//g') -local_swig_ver=$(swig -version | grep -i version | awk '{print $3}') -local_swig_ver_flat=$(echo $local_swig_ver | sed 's/\.//g' | sed 's/^0*//g') -if [[ -z $local_swig_ver_flat ]]; then - local_swig_ver_flat=0 -fi -if (( $local_swig_ver_flat < $swig_ver_flat )); then - set -e - wget -q http://downloads.sourceforge.net/swig/swig-3.0.8.tar.gz - tar xzf swig-3.0.8.tar.gz - pushd swig-3.0.8 - apt-get install -y --no-install-recommends libpcre3-dev - ./configure - make - make install - rm -f /usr/bin/swig - ln -s /usr/local/bin/swig /usr/bin/swig - popd - rm -rf swig-3.0.8 swig-3.0.8.tar.gz -fi set -e # Install Python 3.5 and dev library apt-get install -y --no-install-recommends python3.5 libpython3.5-dev @@ -92,3 +71,5 @@ pip3.5 install portpicker pip3.5 install werkzeug pip3.5 install grpcio + +# LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh) diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh new file mode 100755 index 00000000000..c354aaa154e --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh @@ -0,0 +1,75 @@ +#!/usr/bin/env bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Install packages required by Python3.6 build + +# TODO(amitpatankar): Remove this file once we upgrade to ubuntu:16.04 +# docker images for Python 3.6 builds. + +# LINT.IfChange + +# fkrull/deadsnakes is for Python3.6 +add-apt-repository -y ppa:fkrull/deadsnakes +apt-get update + +set -e +# Install Python 3.6 and dev library +apt-get install -y --no-install-recommends python3.6 libpython3.6-dev + +# Install pip3.6 +set +e +pip35_version=$(pip3.6 --version | grep "python 3.6") +if [[ -z $pip35_version ]]; then + set -e + wget -q https://bootstrap.pypa.io/get-pip.py + python3.6 get-pip.py + rm -f get-pip.py +fi + +set -e +# Install six. +pip3.6 install --upgrade absl-py +pip3.6 install --upgrade six==1.10.0 + +# Install protobuf. +pip3.6 install --upgrade protobuf==3.3.0 + +# Remove obsolete version of six, which can sometimes confuse virtualenv. +rm -rf /usr/lib/python3/dist-packages/six* + +# Install numpy, scipy and scikit-learn required by the builds + +# numpy needs to be installed from source to fix segfaults. See: +# https://github.com/tensorflow/tensorflow/issues/6968 +# This workaround isn't needed for Ubuntu 16.04 or later. +pip3.6 install --no-binary=:all: --upgrade numpy==1.12.0 + +pip3.6 install scipy==0.18.1 + +pip3.6 install scikit-learn==0.18.1 + +# pandas required by `inflow` +pip3 install pandas==0.19.2 + +# Install recent-enough version of wheel for Python 3.6 wheel builds +pip3.6 install wheel==0.29.0 + +pip3.6 install portpicker + +pip3.6 install werkzeug + +pip3.6 install grpcio + +# LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh) diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index cc46dd5162b..3677aaa886f 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -66,6 +66,9 @@ BLACKLIST = [ "//tensorflow/contrib/timeseries/examples:data/period_trend.csv", # pylint:disable=line-too-long "//tensorflow/contrib/timeseries/python/timeseries:test_utils", "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:test_utils", # pylint:disable=line-too-long + + # TODO(yifeif): Remove when py_library(testonly=1) is ignored. + "//tensorflow/contrib/summary:summary_test_internal", ] diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 19e1deb95da..8e62228c1b7 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -152,7 +152,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "mkl", urls = [ "https://mirror.bazel.build/github.com/01org/mkl-dnn/releases/download/v0.9/mklml_lnx_2018.0.20170720.tgz", - # "https://github.com/01org/mkl-dnn/releases/download/v0.9/mklml_lnx_2018.0.20170720.tgz", + "https://github.com/01org/mkl-dnn/releases/download/v0.9/mklml_lnx_2018.0.20170720.tgz", ], sha256 = "57ba56c4c243f403ff78f417ff854ef50b9eddf4a610a917b7c95e7fa8553a4b", strip_prefix = "mklml_lnx_2018.0.20170720", @@ -211,7 +211,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "libxsmm_archive", urls = [ "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.8.1.tar.gz", - # "https://github.com/hfp/libxsmm/archive/1.8.1.tar.gz", + "https://github.com/hfp/libxsmm/archive/1.8.1.tar.gz", ], sha256 = "2ade869c3f42f23b5263c7d594aa3c7e5e61ac6a3afcaf5d6e42899d2a7986ce", strip_prefix = "libxsmm-1.8.1", @@ -238,7 +238,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "com_googlesource_code_re2", urls = [ "https://mirror.bazel.build/github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", - # "https://github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", + "https://github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", ], sha256 = "bd63550101e056427c9e7ff12a408c1c8b74e9803f393ca916b2926fc2c4906f", strip_prefix = "re2-b94b7cd42e9f02673cd748c1ac1d16db4052514c", @@ -247,8 +247,8 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "gemmlowp", urls = [ - "https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip" - # "https://github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", + "https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", + "https://github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", ], sha256 = "dd2557072bde12141419cb8320a9c25e6ec41a8ae53c2ac78c076a347bb46d9d", strip_prefix = "gemmlowp-010bb3e71a26ca1d0884a167081d092b43563996", @@ -258,7 +258,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "farmhash_archive", urls = [ "https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", - # "https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", + "https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", ], sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0", strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45", @@ -274,7 +274,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "highwayhash", urls = [ "https://mirror.bazel.build/github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", - # "https://github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", + "https://github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", ], sha256 = "0f30a15b1566d93f146c8d149878a06e91d9bb7ec2cfd76906df62a82be4aac9", strip_prefix = "highwayhash-dfcb97ca4fe9277bf9dc1802dd979b071896453b", @@ -296,7 +296,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "jpeg", urls = [ "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", - # "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", + "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", ], sha256 = "c15a9607892113946379ccea3ca8b85018301b200754f209453ab21674268e77", strip_prefix = "libjpeg-turbo-1.5.1", @@ -308,7 +308,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "png_archive", urls = [ "https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.2.53.tar.gz", - # "https://github.com/glennrp/libpng/archive/v1.2.53.tar.gz", + "https://github.com/glennrp/libpng/archive/v1.2.53.tar.gz", ], sha256 = "716c59c7dfc808a4c368f8ada526932be72b2fcea11dd85dc9d88b1df1dfe9c2", strip_prefix = "libpng-1.2.53", @@ -351,6 +351,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "absl_py", urls = [ + "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/231e3870b976c1dc61dce1749138661d21556028.tar.gz", "https://github.com/abseil/abseil-py/archive/231e3870b976c1dc61dce1749138661d21556028.tar.gz", ], sha256 = "8ea2b23bfdb9ae7622f3e5d95236bc600c8d8509a2f38c84732b3145585d4f73", @@ -372,7 +373,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "com_github_andreif_codegen", urls = [ "https://mirror.bazel.build/github.com/andreif/codegen/archive/1.0.tar.gz", - # "https://github.com/andreif/codegen/archive/1.0.tar.gz", + "https://github.com/andreif/codegen/archive/1.0.tar.gz", ], sha256 = "2dadd04a2802de27e0fe5a19b76538f6da9d39ff244036afa00c1bba754de5ee", strip_prefix = "codegen-1.0", @@ -395,12 +396,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): actual = "@six_archive//:six", ) - # TODO(gunan): Add github mirror back if/when sha256sum issues are resolved. - # See https://github.com/libgit2/libgit2/issues/4343 for contetxt. patched_http_archive( name = "protobuf_archive", urls = [ "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", ], sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", @@ -424,31 +424,31 @@ def tf_workspace(path_prefix="", tf_repo_name=""): # We need to import the protobuf library under the names com_google_protobuf # and com_google_protobuf_cc to enable proto_library support in bazel. # Unfortunately there is no way to alias http_archives at the moment. - # TODO(gunan): Add github mirror back if/when sha256sum issues are resolved. native.http_archive( name = "com_google_protobuf", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", ], - sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93", - strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66", + sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", + strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", ) - # TODO(gunan): Add github mirror back if/when sha256sum issues are resolved. native.http_archive( name = "com_google_protobuf_cc", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", ], - sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93", - strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66", + sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", + strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", ) native.http_archive( name = "nsync", urls = [ "https://mirror.bazel.build/github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz", - # "https://github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz", + "https://github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz", ], sha256 = "e3bd4555415ace511338fc27e595351738eea4e9006f1612b76c82914770716b", strip_prefix = "nsync-93815892dddafe9146a5f7e7042281d59d0f4323", @@ -458,7 +458,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "com_google_googletest", urls = [ "https://mirror.bazel.build/github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", - # "https://github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", + "https://github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", ], sha256 = "9cbca84c4256bed17df2c8f4d00c912c19d247c11c9ba6647cd6dd5b5c996b8d", strip_prefix = "googletest-9816b96a6ddc0430671693df90192bbee57108b6", @@ -468,7 +468,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "com_github_gflags_gflags", urls = [ "https://mirror.bazel.build/github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", - # "https://github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", + "https://github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", ], sha256 = "4d222fab8f1ede4709cdff417d15a1336f862d7334a81abf76d09c15ecf9acd1", strip_prefix = "gflags-f8a0efe03aa69b3336d8e228b37d4ccb17324b88", @@ -536,11 +536,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "grpc", urls = [ - # "https://mirror.bazel.build/github.com/grpc/grpc/archive/54e8f37e537794c2d814c1604c1282125f64f093.tar.gz", + "https://mirror.bazel.build/github.com/grpc/grpc/archive/54e8f37e537794c2d814c1604c1282125f64f093.tar.gz", "https://github.com/grpc/grpc/archive/54e8f37e537794c2d814c1604c1282125f64f093.tar.gz", ], - sha256 = "c2166b6d96daddf72fe45b2c594210c65ca17ec3c1b2e12089159a9529edb5e4", - strip_prefix = "grpc-54e8f37e537794c2d814c1604c1282125f64f093", + sha256 = "c2166b6d96daddf72fe45b2c594210c65ca17ec3c1b2e12089159a9529edb5e4", + strip_prefix = "grpc-54e8f37e537794c2d814c1604c1282125f64f093", ) # gRPC wants the existence of a cares dependence but its contents are not @@ -567,7 +567,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7", urls = [ "https://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", - # "https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", + "https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", ], strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3", build_file = str(Label("//third_party:linenoise.BUILD")), @@ -578,11 +578,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/618cf290880ae9cd87b4bbf6c9b1759476f422eb.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/618cf290880ae9cd87b4bbf6c9b1759476f422eb.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/823bedeb8e23a095173389fa05680597eba3f569.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/823bedeb8e23a095173389fa05680597eba3f569.tar.gz", ], - sha256 = "ec2e032e58372c614c41b539c0309baa91843c30d7a9c6dee647dcd24be02e3c", - strip_prefix = "llvm-618cf290880ae9cd87b4bbf6c9b1759476f422eb", + sha256 = "93464bc760fd0319ebd0a5831fe477fdc4954f3612a29cc64d7405eaee8e00b2", + strip_prefix = "llvm-823bedeb8e23a095173389fa05680597eba3f569", build_file = str(Label("//third_party/llvm:llvm.BUILD")), repository = tf_repo_name, ) @@ -591,7 +591,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "lmdb", urls = [ "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", - # "https://github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", + "https://github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", ], sha256 = "108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326", strip_prefix = "lmdb-LMDB_0.9.19/libraries/liblmdb", @@ -602,7 +602,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "jsoncpp_git", urls = [ "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", - # "https://github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", + "https://github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", ], sha256 = "07d34db40593d257324ec5fb9debc4dc33f29f8fb44e33a2eeb35503e61d0fe2", strip_prefix = "jsoncpp-11086dd6a7eba04289944367ca82cea71299ed70", @@ -618,6 +618,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "boringssl", urls = [ "https://mirror.bazel.build/github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz", + "https://github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz", ], sha256 = "524ba98a56300149696481b4cb9ddebd0c7b7ac9b9f6edee81da2d2d7e5d2bb3", strip_prefix = "boringssl-a0fb951d2a26a8ee746b52f3ba81ab011a0af778", @@ -653,7 +654,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "snappy", urls = [ "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.4.tar.gz", - # "https://github.com/google/snappy/archive/1.1.4.tar.gz", + "https://github.com/google/snappy/archive/1.1.4.tar.gz", ], sha256 = "2f7504c73d85bac842e893340333be8cb8561710642fc9562fccdd9d2c3fcc94", strip_prefix = "snappy-1.1.4", @@ -665,7 +666,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "nccl_archive", urls = [ "https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", - # "https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", + "https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", ], sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176", strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7", @@ -676,8 +677,8 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "aws", urls = [ - "http://bazel-mirror.storage.googleapis.com/github.com/aws/aws-sdk-cpp/archive/1.0.90.tar.gz", - # "https://github.com/aws/aws-sdk-cpp/archive/1.0.90.tar.gz", + "https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.0.90.tar.gz", + "https://github.com/aws/aws-sdk-cpp/archive/1.0.90.tar.gz", ], sha256 = "f599b57aec4f03ad696044dd430b2d201864113937353adc346f53ad47991319", strip_prefix = "aws-sdk-cpp-1.0.90", @@ -714,7 +715,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "jemalloc", urls = [ "https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", - # "https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", + "https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", ], sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8", strip_prefix = "jemalloc-4.4.0", @@ -761,7 +762,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "com_google_pprof", urls = [ "https://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", - # "https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", + "https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", ], sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4", strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650", @@ -772,7 +773,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "cub_archive", urls = [ "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip", - # "https://github.com/NVlabs/cub/archive/1.7.4.zip", + "https://github.com/NVlabs/cub/archive/1.7.4.zip", ], sha256 = "20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31", strip_prefix = "cub-1.7.4", @@ -799,7 +800,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "bazel_toolchains", urls = [ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/af4681c3d19f063f090222ec3d04108c4e0ca255.tar.gz", - # "https://github.com/bazelbuild/bazel-toolchains/archive/af4681c3d19f063f090222ec3d04108c4e0ca255.tar.gz", + "https://github.com/bazelbuild/bazel-toolchains/archive/af4681c3d19f063f090222ec3d04108c4e0ca255.tar.gz", ], sha256 = "d58bb2d6c8603f600d522b6104d6192a65339aa26cbba9f11ff5c4b36dedb928", strip_prefix = "bazel-toolchains-af4681c3d19f063f090222ec3d04108c4e0ca255", @@ -832,6 +833,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b", urls = [ - "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" + "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", ], ) diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD index 97b833e49d5..5344525ba8b 100644 --- a/third_party/llvm/llvm.BUILD +++ b/third_party/llvm/llvm.BUILD @@ -7,18 +7,18 @@ licenses(["notice"]) exports_files(["LICENSE.TXT"]) load( - "@%ws%//third_party/llvm:llvm.bzl", + "@org_tensorflow//third_party/llvm:llvm.bzl", "gentbl", "expand_cmake_vars", "llvm_target_cmake_vars", "cmake_var_string", ) load( - "@%ws%//third_party:common.bzl", + "@org_tensorflow//third_party:common.bzl", "template_rule", ) -package(default_visibility = ["@%ws%//tensorflow/compiler/xla:internal"]) +package(default_visibility = ["//visibility:public"]) llvm_host_triple = "x86_64-unknown-linux_gnu" @@ -145,11 +145,11 @@ darwin_cmake_vars = { # TODO(phawkins): use a better method to select the right host triple, rather # than hardcoding x86_64. all_cmake_vars = select({ - "@%ws%//tensorflow:darwin": cmake_var_string( + "@org_tensorflow//tensorflow:darwin": cmake_var_string( cmake_vars + llvm_target_cmake_vars("X86", "x86_64-apple-darwin") + darwin_cmake_vars, ), - "@%ws%//tensorflow:linux_ppc64le": cmake_var_string( + "@org_tensorflow//tensorflow:linux_ppc64le": cmake_var_string( cmake_vars + llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu") + linux_cmake_vars,