diff --git a/RELEASE.md b/RELEASE.md index ae41d56e147..763ef3b279d 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -3,7 +3,7 @@ ## Major Features And Improvements * The `tf.lite` runtime now supports `complex64`. -* Initial Bigtable integration for `tf.data`. +* Initial [Google Cloud Bigtable integration](https://github.com/tensorflow/tensorflow/tree/r1.10/tensorflow/contrib/bigtable) for `tf.data`. * Improved local run behavior in `tf.estimator.train_and_evaluate` which does not reload checkpoints for evaluation. * `RunConfig` now sets device_filters to restrict how workers and PS can communicate. This can speed up training and ensure clean shutdowns in some situations. But if you have jobs that require communication between workers, you will have to set custom session_options in your `RunConfig`. * Moved Distributions and Bijectors from `tf.contrib.distributions` to [Tensorflow Probability (TFP)](https://github.com/tensorflow/probability). `tf.contrib.distributions` is now deprecated and will be removed by the end of 2018. diff --git a/configure.py b/configure.py index 6d0c0774068..7acc6932eb2 100644 --- a/configure.py +++ b/configure.py @@ -839,15 +839,16 @@ def set_tf_cuda_version(environ_cp): cuda_toolkit_path = cygpath(cuda_toolkit_path) if is_windows(): - cuda_rt_lib_path = 'lib/x64/cudart.lib' + cuda_rt_lib_paths = ['lib/x64/cudart.lib'] elif is_linux(): - cuda_rt_lib_path = 'lib64/libcudart.so.%s' % tf_cuda_version + cuda_rt_lib_paths = ['%s/libcudart.so.%s' % (x, tf_cuda_version) + for x in ['lib64', 'lib/x86_64-linux-gnu']] elif is_macos(): - cuda_rt_lib_path = 'lib/libcudart.%s.dylib' % tf_cuda_version + cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version] - cuda_toolkit_path_full = os.path.join(cuda_toolkit_path, cuda_rt_lib_path) - if os.path.exists(cuda_toolkit_path_full): - break + cuda_toolkit_paths_full = [os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths] + if any([os.path.exists(x) for x in cuda_toolkit_paths_full]): + break # Reset and retry print('Invalid path to CUDA %s toolkit. 
%s cannot be found' % @@ -1398,10 +1399,6 @@ def set_grpc_build_flags(): write_to_bazelrc('build --define grpc_no_ares=true') -def set_build_strip_flag(): - write_to_bazelrc('build --strip=always') - - def set_windows_build_flags(environ_cp): """Set Windows specific build options.""" # The non-monolithic build is not supported yet @@ -1560,7 +1557,6 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) - set_build_strip_flag() if is_windows(): set_windows_build_flags(environ_cp) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f3d8d558ac8..e5654a5141d 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -124,12 +124,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "windows_msvc", - values = {"cpu": "x64_windows_msvc"}, - visibility = ["//visibility:public"], -) - config_setting( name = "no_tensorflow_py_deps", define_values = {"no_tensorflow_py_deps": "true"}, @@ -439,14 +433,14 @@ package_group( load( "//third_party/mkl:build_defs.bzl", - "if_mkl", + "if_mkl_ml", ) filegroup( name = "intel_binary_blob", - data = if_mkl( + data = if_mkl_ml( [ - "//third_party/mkl:intel_binary_blob", + "//third_party/intel_mkl_ml", ], ), ) @@ -497,7 +491,6 @@ tf_cc_shared_object( linkopts = select({ "//tensorflow:darwin": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow:tf_framework_version_script.lds)", @@ -539,7 +532,6 @@ tf_cc_shared_object( "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file @@ -564,7 +556,6 @@ tf_cc_shared_object( "$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index f4be60a183b..f56521dac03 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -628,7 +628,6 @@ tf_cc_binary( copts = tf_copts(), linkopts = select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//tensorflow:darwin": [ "-lm", "-lpthread", diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 55b98da4720..e059f77563b 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -314,12 +314,16 @@ cc_library( "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", "mark_for_compilation_pass.cc", + "mark_for_compilation_pass_test_helper.cc", + "partially_decluster_pass.cc", ], hdrs = [ "build_xla_launch_ops_pass.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", "mark_for_compilation_pass.h", + "mark_for_compilation_pass_test_helper.h", + "partially_decluster_pass.h", ], deps = [ ":common", @@ -354,6 +358,7 @@ cc_library( "//tensorflow/compiler/jit/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", ], @@ -418,10 +423,12 @@ tf_cc_test( srcs = [ "encapsulate_subgraphs_pass_test.cc", "mark_for_compilation_pass_test.cc", + "partially_decluster_pass_test.cc", ], deps = [ ":common", ":compilation_passes", + ":xla_cluster_util", "//tensorflow/cc:cc_ops", 
"//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 4d49a14b24d..c37b6112cc8 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { @@ -23,15 +24,18 @@ namespace tensorflow { REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, + PartiallyDeclusterPass); + // The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We // also need to run it after the graph been rewritten to have _Send nodes added // for fetches. Before the _Send nodes are added, fetch nodes are identified by // name, and encapsulation might remove that node from the graph. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, EncapsulateSubgraphsPass); // Must run after EncapsulateSubgraphsPass. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, BuildXlaLaunchOpsPass); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 37a2f3b5ac9..7f4370b5b07 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -210,7 +210,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; - launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie()); + OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( + ctx, kernel, run_result.ConsumeValueOrDie())); VLOG(1) << "Done"; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 45d422943c2..90d5d56998c 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -65,6 +65,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave // such nodes out of XLA clusters. 
if (HasForwardedRefInput(node)) { + VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast."; return false; } @@ -84,14 +85,13 @@ bool IsCompilableCall(const NodeDef& call_def, bool IsCompilableWhile(const Node& while_node, const DeviceType& jit_device_type, int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Loop marking: " << while_node.type_string(); - const NameAttrList* name_attr; NodeDef call; Status status; status = GetNodeAttr(while_node.attrs(), "cond", &name_attr); if (!status.ok()) { - VLOG(2) << "Missing 'cond' attribute on While node."; + VLOG(2) << "Rejecting While " << while_node.name() + << ": missing 'cond' attribute on While node."; return false; } const string cond_func = name_attr->name(); @@ -99,12 +99,14 @@ bool IsCompilableWhile(const Node& while_node, call.set_op(cond_func); *call.mutable_attr() = name_attr->attr(); if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Can't compile loop condition: " << cond_func; + VLOG(2) << "Rejecting While " << while_node.name() + << ": can't compile loop condition: " << cond_func; return false; } status = GetNodeAttr(while_node.attrs(), "body", &name_attr); if (!status.ok()) { - VLOG(2) << "Missing 'body' attribute on While node."; + VLOG(2) << "Rejecting While " << while_node.name() + << ": missing 'body' attribute on While node."; return false; } const string body_func = name_attr->name(); @@ -112,10 +114,10 @@ bool IsCompilableWhile(const Node& while_node, call.set_op(body_func); *call.mutable_attr() = name_attr->attr(); if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Can't compile loop body: " << body_func; + VLOG(2) << "Rejecting While " << while_node.name() + << ": can't compile loop body: " << body_func; return false; } - VLOG(2) << "Loop is compilable."; return true; } @@ -125,10 +127,9 @@ bool IsCompilableWhile(const Node& while_node, bool IsCompilableCall(const NodeDef& call_def, const DeviceType& jit_device_type, int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Function marking: " << call_def.op(); - if (depth > kMaxRecursionDepth) { - VLOG(2) << "Function depth limit exceeded"; + VLOG(2) << "Rejecting " << call_def.op() + << ": function depth limit exceeded."; return false; } @@ -136,7 +137,8 @@ bool IsCompilableCall(const NodeDef& call_def, Status status = lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle); if (!status.ok()) { - VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status; + VLOG(2) << "Rejecting " << call_def.op() + << ": could not instantiate: " << status; return false; } const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); @@ -150,7 +152,8 @@ bool IsCompilableCall(const NodeDef& call_def, // tf2xla to translate the TF graph into XLA. So we avoid this for now. // // TODO(b/36139787): Create a mechanism to set inlining hints. 
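
The "Rejecting While ..." messages above come from the attribute lookups the pass already performs: a While node's loop condition and body are stored as `NameAttrList` attributes, and each is turned into a synthetic function-call NodeDef before the recursive compilability check. A minimal sketch of that lookup pattern, assuming a helper name (`MakeWhileBodyCall`) that is not part of the patch:

```c++
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical helper: extracts the "body" function of a While node and
// rewrites it as a call-style NodeDef, mirroring what IsCompilableWhile does
// for both "cond" and "body" before recursing into IsCompilableCall.
Status MakeWhileBodyCall(const Node& while_node, NodeDef* call) {
  const NameAttrList* name_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node.attrs(), "body", &name_attr));
  call->set_name("while_body_call");
  call->set_op(name_attr->name());            // The body function's name.
  *call->mutable_attr() = name_attr->attr();  // Forward its attributes.
  return Status::OK();
}

}  // namespace tensorflow
```
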
- VLOG(2) << "Can't compile noinline function: " << fdef.DebugString(); + VLOG(2) << "Rejecting " << call_def.op() + << ": can't compile noinline function."; return false; } @@ -164,23 +167,14 @@ bool IsCompilableCall(const NodeDef& call_def, if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Function marking failed: unsupported op " << node->name() - << ": " << node->def().ShortDebugString(); + VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op " + << node->name() << ": " << node->def().ShortDebugString(); return false; } } - VLOG(2) << "Function is compilable: " << call_def.op(); return true; } -// Tests whether `node` has a DT_RESOURCE typed input or output. -bool HasResourceInputOrOutput(const Node& node) { - return std::find(node.input_types().begin(), node.input_types().end(), - DT_RESOURCE) != node.input_types().end() || - std::find(node.output_types().begin(), node.output_types().end(), - DT_RESOURCE) != node.output_types().end(); -} - // Returns true if the op can be decomposed into XLA ops for which // there are fusable elemental implementations. // @@ -357,24 +351,27 @@ Status FindCompilationCandidates( } std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); + if (fuel >= std::numeric_limits::max() / 2) { + // The assumption is that if fuel started out as INT64_MAX, it will forever + // stay greater than INT64_MAX / 2. + VLOG(2) << "Starting fuel: infinity"; + } else { + VLOG(2) << "Starting fuel: " << fuel; + } + for (Node* node : sorted_nodes) { - VLOG(2) << "Fuel: " << fuel; if (fuel <= 0) { - VLOG(2) + VLOG(1) << "Hit fuel limit; not marking any remaining ops as clusterable."; break; } - VLOG(2) << "FindCompilationCandidates(): Processing " - << node->DebugString(); - DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceToDeviceType(node->assigned_device_name(), &device_type)); if (is_compilable_fn && !is_compilable_fn(node, device_type)) { - VLOG(2) << "Compilation rejected node: not compilable " << node->name() - << ": " << node->type_string(); + // is_compilable_fn has already logged the reason if it returned false. continue; } @@ -384,14 +381,14 @@ Status FindCompilationCandidates( DeviceType jit_device_type(registration->compilation_device_name); if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) { - VLOG(2) << "Compilation rejected node: unsupported op " << node->name() - << ": " << node->type_string(); + VLOG(2) << "Rejecting " << node->name() << ": unsupported op " + << node->type_string(); continue; } if (!registration->compile_resource_ops && HasResourceInputOrOutput(*node)) { - VLOG(2) << "Compilation rejected node: resource input/output " - << node->name() << ": " << node->type_string(); + VLOG(2) << "Rejecting: " << node->name() << ": resource input/output " + << node->type_string(); continue; } if (node->type_string() == "While" && @@ -401,15 +398,11 @@ Status FindCompilationCandidates( // _Arg nodes in a top-level function represent feeds. // Do not compile them. if (node->type_string() == "_Arg") { - VLOG(2) << "Skipping jit compilation for '_Arg'-typed node " - << node->DebugString(); continue; } // _Retval nodes in a top-level function represent fetches. // Do not compile them. 
if (node->type_string() == "_Retval") { - VLOG(2) << "Compilation rejected node: return value " << node->name() - << ": " << node->type_string(); continue; } candidates->insert(node); @@ -475,6 +468,7 @@ Status MarkForCompilationPass::Run( const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device."; return false; } @@ -484,21 +478,36 @@ Status MarkForCompilationPass::Run( // If there is a _XlaCompile annotation, use its value. bool compile = false; Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); - if (status.ok()) return compile; + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") is false."; + } + return compile; + } status = fld->GetAttr(*node, kXlaCompileAttr, &compile); - if (status.ok()) return compile; + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") on callee is false."; + } + return compile; + } // If inputs to `node` can have conflicting deadness (i.e. some are alive // and some are dead) then don't compile it. XLA cannot represent the // deadness semantics of these nodes correctly and auto-clustering these // nodes can cause deadness to propagate to nodes that should be live. if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { + VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; return false; } // Check for fusable ops only if requested. if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) { + VLOG(2) << "Rejecting " << node->name() + << ": not fusable op but fusion_only enabled."; return false; } @@ -506,8 +515,17 @@ Status MarkForCompilationPass::Run( // Ignore enable_jit_by_default if global jit compilation for CPU // is explicitly requested via tf_xla_cpu_global_jit flag bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; - return (ignore_registration || registration->enable_jit_by_default) && - global_jit_level > 0; + bool should_compile = + (ignore_registration || registration->enable_jit_by_default) && + global_jit_level > 0; + if (!should_compile) { + if (global_jit_level <= 0) { + VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; + } else { + VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled."; + } + } + return should_compile; }; return RunImpl(options, is_compilable); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index e9acbfb19e4..f1137af3c1e 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -40,20 +40,18 @@ class MarkForCompilationPass : public GraphOptimizationPass { Status Run(const GraphOptimizationPassOptions& options) override; - // Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass - // unconditionally, call RunImpl() directly. - // is_compilable_fn, if set, is a predicate that must be true for a node to - // be compiled. + private: Status RunImpl(const GraphOptimizationPassOptions& options, const std::function& is_compilable_fn = {}); + + friend class MarkForCompilationPassTestHelper; }; // Returns true iff 'ndef' is a call to a function that is compilable. A // function is compilable iff every operator in the function body is // compilable. 
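
Making `RunImpl` private and befriending `MarkForCompilationPassTestHelper`, as done in the header change above, keeps the unconditional entry point out of the public API while still letting tests drive the pass directly. A generic sketch of that friend-test-helper pattern, with hypothetical names (`WidgetPass`, `WidgetPassTestPeer`) that are not part of the patch:

```c++
#include <iostream>
#include <string>

// Production class: only the public Run() is meant for normal callers.
class WidgetPass {
 public:
  std::string Run() { return "conditional: " + RunImpl(); }

 private:
  // Unconditional implementation, reachable from tests via the friend below.
  std::string RunImpl() { return "did work"; }

  friend class WidgetPassTestPeer;
};

// Test-only peer: forwards to the private implementation so tests can bypass
// the conditions Run() would normally check.
class WidgetPassTestPeer {
 public:
  static std::string RunUnconditionally(WidgetPass* pass) {
    return pass->RunImpl();
  }
};

int main() {
  WidgetPass pass;
  std::cout << WidgetPassTestPeer::RunUnconditionally(&pass) << "\n";
  return 0;
}
```
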
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 2c5f4fb774f..a780d4a936a 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" @@ -39,27 +39,6 @@ namespace { REGISTER_OP("UncompilableNullary").Output("o: float"); REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); -Status MarkForCompilation(std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def) { - // Assign all nodes to the CPU device. - static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; - for (Node* n : (*graph)->nodes()) { - n->set_assigned_device_name(kCpuDevice); - } - - GraphOptimizationPassOptions opt_options; - opt_options.graph = graph; - opt_options.flib_def = flib_def; - MarkForCompilationPass pass; - return pass.RunImpl(opt_options); -} - -Status MarkForCompilation(std::unique_ptr* graph) { - FunctionDefLibrary flib; - FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); - return MarkForCompilation(graph, &flib_def); -} - std::unordered_map GetClusters(const Graph& graph) { std::unordered_map ids; for (Node* node : graph.nodes()) { @@ -88,7 +67,7 @@ TEST(XlaCompilationTest, Chains) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(4, clusters.size()); EXPECT_EQ(clusters["B"], clusters["C"]); @@ -113,7 +92,7 @@ TEST(XlaCompilationTest, UncompilableCycles) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -133,7 +112,7 @@ TEST(XlaCompilationTest, CompilableCycles) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(3, clusters.size()); @@ -156,7 +135,7 @@ TEST(XlaCompilationTest, Complex128Unsupported) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); } @@ -177,7 +156,7 @@ TEST(XlaCompilationTest, HalfSupported) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_FALSE(clusters.empty()); } @@ -206,7 +185,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) { 
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(3, clusters.size()); // Everything should be compiled. } @@ -241,7 +220,8 @@ TEST(XlaCompilationTest, FunctionCalls) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def)); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -272,7 +252,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) { ops::UnaryOp("Shape", d, builder.opts().WithName("E")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. } @@ -359,7 +339,7 @@ TEST(XlaCompilationTest, SymbolicGradients) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -384,7 +364,7 @@ TEST(XlaCompilationTest, Loops) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // Nothing should be compiled. In particular, 'd' and 'c' must not be @@ -411,7 +391,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: C = A + relu(A) @@ -442,7 +422,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: D = relu(A) + (A @ relu(A)) @@ -472,7 +452,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: C = A @ relu(A) @@ -512,7 +492,7 @@ TEST(XlaCompilationTest, Resources) { ops::UnaryOp("Relu", d, builder.opts().WithName("E")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. 
} @@ -542,7 +522,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { TF_EXPECT_OK(root.ToGraph(graph.get())); - Status status = MarkForCompilation(&graph); + Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); EXPECT_TRUE(str_util::StrContains(status.ToString(), "Edge from c to a would create a cycle.\n" @@ -570,7 +550,7 @@ TEST(XlaCompilationTest, Retval) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -588,7 +568,7 @@ TEST(XlaCompilationTest, DontCountIdentityOps) { auto r = ops::_Retval(root.WithOpName("R"), c, 0); } TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -604,7 +584,7 @@ TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) { auto r = ops::_Retval(root.WithOpName("R"), b, 0); } TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -618,7 +598,7 @@ TEST(XlaCompilationTest, ConstOp) { auto c = ops::Const(root.WithOpName("const"), 0.5f); c.node()->AddAttr(kXlaCompileAttr, true); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); EXPECT_EQ(1, GetClusters(*graph).size()); } @@ -629,7 +609,7 @@ TEST(XlaCompilationTest, ConstOp) { auto c = ops::Const(root.WithOpName("const"), string("string")); c.node()->AddAttr(kXlaCompileAttr, true); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); EXPECT_TRUE(GetClusters(*graph).empty()); } } @@ -644,7 +624,7 @@ TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); @@ -667,7 +647,7 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); @@ -699,7 +679,7 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc new file mode 100644 index 00000000000..a84b82e4792 --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" + +namespace tensorflow { +/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + // Assign all nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : (*graph)->nodes()) { + n->set_assigned_device_name(kCpuDevice); + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = graph; + opt_options.flib_def = flib_def; + MarkForCompilationPass pass; + return pass.RunImpl(opt_options); +} + +/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( + std::unique_ptr* graph) { + FunctionDefLibrary flib; + FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); + return MarkForCompilation(graph, &flib_def); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h new file mode 100644 index 00000000000..b9a0531cb0e --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ +#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ + +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" + +namespace tensorflow { +class MarkForCompilationPassTestHelper { + public: + // Runs the MarkForCompilation pass on `graph` after assigning all nodes in + // `graph` to the CPU device. To make testing easier, ignores device + // registration, _XlaCompile attributes, input deadness and global jit level. + static Status MarkForCompilation(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def); + + // Like `MarkForCompilation` but creates `flib_def` from the op registry. 
+ static Status MarkForCompilation(std::unique_ptr* graph); +}; +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc new file mode 100644 index 00000000000..68ead39424c --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace tensorflow { +namespace { +Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, + gtl::ArraySlice post_order) { + // Find nodes that have at least one user outside their cluster that expects + // hostmem output. These nodes should be cloned to outside the cluster to + // avoid the device-host copy we'd otherwise need. + + MemoryTypeVector input_mtypes, output_mtypes; + + for (Node* n : post_order) { + gtl::optional from_cluster = GetXlaClusterForNode(*n); + if (!from_cluster) { + continue; + } + + // We assume the only XLA-auto-clusterable operations with side effects are + // resource variable updates. We can't execute these twice. + if (HasResourceInputOrOutput(*n)) { + continue; + } + + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n->assigned_device_name(), &device_type)); + TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, + n->def(), &input_mtypes, + &output_mtypes)); + for (const Edge* e : n->out_edges()) { + Node* dst = e->dst(); + + if (e->IsControlEdge()) { + continue; + } + + bool edge_incurs_extra_device_to_host_copy; + if (output_mtypes[e->src_output()] == DEVICE_MEMORY) { + // If the output of the *TensorFlow* operation is in DEVICE_MEMORY then + // keep the node clustered -- XLA will also produce the output in device + // memory and we will get some benefit from clustering. + edge_incurs_extra_device_to_host_copy = false; + } else { + MemoryTypeVector dst_input_mtypes, dst_output_mtypes; + DeviceType dst_device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(dst->assigned_device_name(), &dst_device_type)); + TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, + dst->def(), &dst_input_mtypes, + &dst_output_mtypes)); + edge_incurs_extra_device_to_host_copy = + dst_input_mtypes[e->dst_input()] == HOST_MEMORY; + } + + if (!edge_incurs_extra_device_to_host_copy) { + continue; + } + + // Check if `dst` is in a different cluster, unclustered, or about to be + // partially declustered (here we rely on the post-order traversal order). + // If yes, decluster `n` to avoid the device-to-host memcpy. 
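
The host-memory check in FindNodesToDecluster above relies on the kernel registry: `MemoryTypesForNode` reports, per input and output, whether the registered kernel expects the tensor in host or device memory. A minimal sketch of querying it for a single node, assuming the node already has an assigned device and that `DeviceToDeviceType` is available from xla_cluster_util.h as in the pass above; the helper name `CountHostMemoryInputs` is illustrative, not from the patch.

```c++
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical helper: counts how many inputs of `node`'s registered kernel
// are constrained to HOST_MEMORY on the node's assigned device. This is the
// same query FindNodesToDecluster uses to decide whether an edge would force
// a device-to-host copy.
Status CountHostMemoryInputs(const Graph& graph, const Node& node,
                             int* num_host_inputs) {
  DeviceType device_type("");
  TF_RETURN_IF_ERROR(
      DeviceToDeviceType(node.assigned_device_name(), &device_type));

  MemoryTypeVector input_mtypes, output_mtypes;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
                                        node.def(), &input_mtypes,
                                        &output_mtypes));

  *num_host_inputs = 0;
  for (MemoryType mtype : input_mtypes) {
    if (mtype == HOST_MEMORY) ++(*num_host_inputs);
  }
  return Status::OK();
}

}  // namespace tensorflow
```
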
+ gtl::optional dst_cluster = + result->count(dst) ? gtl::nullopt : GetXlaClusterForNode(*dst); + if (from_cluster != dst_cluster) { + CHECK(result->insert(n).second); + break; + } + } + } + return Status::OK(); +} + +Status PartiallyDeclusterNode(Graph* graph, Node* n) { + StringPiece cluster_name = *GetXlaClusterForNode(*n); + gtl::InlinedVector out_edges_to_clone; + for (const Edge* out_edge : n->out_edges()) { + if (out_edge->IsControlEdge()) { + continue; + } + + Node* dst = out_edge->dst(); + gtl::optional dst_cluster_name = GetXlaClusterForNode(*dst); + if (dst_cluster_name != cluster_name) { + out_edges_to_clone.push_back(out_edge); + } + } + + CHECK(!out_edges_to_clone.empty()) << n->DebugString(); + + NodeDef ndef = n->def(); + ndef.set_name(strings::StrCat(n->name(), "/declustered")); + RemoveFromXlaCluster(&ndef); + Status s; + Node* cloned_node = graph->AddNode(ndef, &s); + cloned_node->set_assigned_device_name(n->assigned_device_name()); + TF_RETURN_IF_ERROR(s); + + for (const Edge* in_edge : n->in_edges()) { + graph->AddEdge(in_edge->src(), in_edge->src_output(), cloned_node, + in_edge->dst_input()); + } + + for (const Edge* out_edge_to_clone : out_edges_to_clone) { + graph->AddEdge(cloned_node, out_edge_to_clone->src_output(), + out_edge_to_clone->dst(), out_edge_to_clone->dst_input()); + graph->RemoveEdge(out_edge_to_clone); + } + + return Status::OK(); +} +} // namespace + +Status PartiallyDeclusterPass::Run( + const GraphOptimizationPassOptions& options) { + // NB! In this pass we assume the only XLA-auto-clusterable operations that + // may have side effects are resource variable operations so we don't cluster + // those. The pass will have to be updated if this assumption becomes + // invalid. + + Graph* graph = options.graph->get(); + + // When deciding whether to decluster a particular node, we base our decision + // on if we've decided that some of its consumers have to be declustered too. + // Iterating the graph in post-order guarantees that consumers have been + // visited before producers. + std::vector post_order; + GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + + gtl::FlatSet nodes_to_partially_decluster; + TF_RETURN_IF_ERROR(FindNodesToDecluster( + **options.graph, &nodes_to_partially_decluster, post_order)); + + if (VLOG_IS_ON(3)) { + for (Node* n : post_order) { + if (nodes_to_partially_decluster.count(n)) { + VLOG(3) << n->DebugString(); + } + } + } + + for (Node* n : post_order) { + if (nodes_to_partially_decluster.count(n)) { + TF_RETURN_IF_ERROR(PartiallyDeclusterNode(graph, n)); + } + } + + nodes_to_partially_decluster.clear(); + TF_RETURN_IF_ERROR(FindNodesToDecluster( + **options.graph, &nodes_to_partially_decluster, post_order)); + CHECK(nodes_to_partially_decluster.empty()); + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h new file mode 100644 index 00000000000..6949b5028ee --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// Clones nodes from within a cluster to outside the cluster if profitable. +// +// Today this only clones to avoid device-to-host copies, but in the future we +// may consider other reasons to clone. For instance, we convert this: +// +// ..... +// | +// v +// A_Clustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// to: +// +// ..... +// | | +// | +-------------+ +// | | +// v v +// A_Clustered A_Unclustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// where the ===> arrow has a hostmem source and destination and would entail a +// device to host copy if the source and destination were not in the same XLA +// cluster. +class PartiallyDeclusterPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc new file mode 100644 index 00000000000..08a956e4c64 --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -0,0 +1,284 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/partially_decluster_pass.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +REGISTER_OP("FakeNullary").Output("out: float"); + +REGISTER_OP("FakeBinary") + .Input("host_in: float") + .Input("device_in: float") + .Output("host_out: float") + .Output("device_out: float"); + +REGISTER_OP("FakeResourceVar").Output("out: resource"); + +REGISTER_OP("FakeResourceUpdate") + .Input("in: resource") + .Output("out: resource") + .Output("something_else: float"); + +class FakeBinaryOp : public OpKernel { + public: + explicit FakeBinaryOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { CHECK(false); } +}; + +class FakeResourceVarUpdateOp : public OpKernel { + public: + explicit FakeResourceVarUpdateOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { CHECK(false); } +}; + +REGISTER_KERNEL_BUILDER(Name("FakeBinary") + .Device(DEVICE_CPU) + .HostMemory("host_in") + .HostMemory("host_out"), + FakeBinaryOp); + +REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate") + .Device(DEVICE_CPU) + .HostMemory("something_else"), + FakeResourceVarUpdateOp); + +Status PartiallyDecluster(std::unique_ptr* graph) { + FixupSourceAndSinkEdges(graph->get()); + // Assign all nodes to the CPU device. 
+ static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : (*graph)->nodes()) { + n->set_assigned_device_name(kCpuDevice); + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = graph; + PartiallyDeclusterPass pass; + return pass.Run(opt_options); +} + +const Node* FindNodeByName(const Graph& graph, const string& name) { + for (const Node* node : graph.nodes()) { + if (node->name() == name) { + return node; + } + } + return nullptr; +} + +bool GetInputsForNode(const Graph& graph, const string& node_name, + std::vector* inputs) { + const Node* node = FindNodeByName(graph, node_name); + if (node == nullptr) { + return false; + } + for (const Edge* e : node->in_edges()) { + inputs->push_back(e->src()); + } + std::sort(inputs->begin(), inputs->end(), NodeComparatorName()); + return true; +} + +TEST(PartiallyDeclusterPassTest, ClusteredAndUnclustered) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + ops::BinaryOp("FakeBinary", clustered_producer, input, + builder.opts().WithName("UnclusteredConsumer")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector unclustered_consumer_inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer", + &unclustered_consumer_inputs)); + ASSERT_EQ(unclustered_consumer_inputs.size(), 2); + EXPECT_EQ(unclustered_consumer_inputs[0]->name(), + "ClusteredProducer/declustered"); + EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input"); + + std::vector clustered_consumer_inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredConsumer", + &clustered_consumer_inputs)); + ASSERT_EQ(clustered_consumer_inputs.size(), 2); + EXPECT_EQ(clustered_consumer_inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(clustered_consumer_inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DifferentClusters) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", clustered_producer, input, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), 
"ClusteredProducer/declustered"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DontDeclusterIfUserIsDeviceMem) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + // The first input is hostmem and the second input is devicemem. + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", input, clustered_producer, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DontDuplicateResourceVarOps) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* resource_var = ops::SourceOp("FakeResourceVar", + builder.opts().WithName("ResourceVar")); + Node* clustered_producer = + ops::UnaryOp("FakeResourceUpdate", resource_var, + builder.opts().WithName("ClusteredProducer")); + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer_0 = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer0")); + Node* clustered_producer_1 = + ops::BinaryOp("FakeBinary", clustered_producer_0, input, + builder.opts().WithName("ClusteredProducer1")); + ops::BinaryOp("FakeBinary", clustered_producer_1, input, + builder.opts().WithName("UnclusteredConsumer")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", {clustered_producer_1, 1}, input, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer_0->AddAttr(kXlaClusterAttr, 
"cluster_0"); + clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector unclustered_consumer_inputs, declustered_producer_1_inputs; + + ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer", + &unclustered_consumer_inputs)); + ASSERT_EQ(unclustered_consumer_inputs.size(), 2); + EXPECT_EQ(unclustered_consumer_inputs[0]->name(), + "ClusteredProducer1/declustered"); + EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input"); + + ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredProducer1/declustered", + &declustered_producer_1_inputs)); + ASSERT_EQ(declustered_producer_1_inputs.size(), 2); + EXPECT_EQ(declustered_producer_1_inputs[0]->name(), + "ClusteredProducer0/declustered"); + EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index a5628b12a27..0a025a1fc0b 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -185,4 +185,26 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { return Status::OK(); } +gtl::optional GetXlaClusterForNode(const Node& node) { + const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr); + if (attr_value == nullptr) { + return gtl::nullopt; + } + Status s = AttrValueHasType(*attr_value, "string"); + if (!s.ok()) { + return gtl::nullopt; + } + return attr_value->s(); +} + +bool HasResourceInputOrOutput(const Node& node) { + return std::find(node.input_types().begin(), node.input_types().end(), + DT_RESOURCE) != node.input_types().end() || + std::find(node.output_types().begin(), node.output_types().end(), + DT_RESOURCE) != node.output_types().end(); +} + +void RemoveFromXlaCluster(NodeDef* node_def) { + node_def->mutable_attr()->erase(kXlaClusterAttr); +} } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index bcce082aaf6..bff76da6f9b 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/gtl/optional.h" namespace tensorflow { @@ -44,6 +45,16 @@ bool HasForwardedRefInput(const Node& node); // the enclosing graph. Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); +// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, +// otherwise returns nullopt. +gtl::optional GetXlaClusterForNode(const Node& node); + +// Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute). +void RemoveFromXlaCluster(NodeDef* node_def); + +// Returns true if `node` has a DT_RESOURCE typed input or output. 
+bool HasResourceInputOrOutput(const Node& node); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index f65f89ebf57..dd84fb34c17 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -78,7 +78,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, executable->Run(launch_context.arguments(), run_options); TF_RETURN_IF_ERROR(run_result.status()); - launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie()); + TF_RETURN_IF_ERROR(launch_context.PopulateOutputs( + ctx, result, run_result.ConsumeValueOrDie())); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 4ddeaebd3e4..2a2691a6a40 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -216,6 +217,8 @@ XlaDevice::XlaDevice( transfer_as_literal_(transfer_as_literal), shape_representation_fn_(shape_representation_fn) { VLOG(1) << "Created XLA device " << jit_device_name << " " << this; + thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device", + /*num_threads=*/1)); } XlaDevice::~XlaDevice() { @@ -262,10 +265,12 @@ Status XlaDevice::EnsureDeviceContextOk() { Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend, const string& name, - xla::StreamPool::Ptr* stream, + std::shared_ptr* stream, bool* stream_was_changed) { if (!(*stream) || !(*stream)->ok()) { - TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_)); + xla::StreamPool::Ptr ptr; + TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_)); + *stream = std::shared_ptr(std::move(ptr)); VLOG(1) << "XlaDevice " << this << " new " << name << " " << (*stream)->DebugStreamPointers(); *stream_was_changed = true; @@ -281,8 +286,8 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_, &need_new_device_context)); - se::Stream* host_to_device_stream = stream_.get(); - se::Stream* device_to_host_stream = stream_.get(); + std::shared_ptr host_to_device_stream = stream_; + std::shared_ptr device_to_host_stream = stream_; if (use_multiple_streams_) { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream", &host_to_device_stream_, @@ -290,8 +295,8 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream", &device_to_host_stream_, &need_new_device_context)); - host_to_device_stream = host_to_device_stream_.get(); - device_to_host_stream = device_to_host_stream_.get(); + host_to_device_stream = host_to_device_stream_; + device_to_host_stream = device_to_host_stream_; } if (!need_new_device_context) { @@ -304,9 +309,13 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { if (device_context_) { device_context_->Unref(); } + // The XlaDeviceContext keeps a reference count to the streams, and the + // XlaDeviceContext remains live 
for the duration of a Executor run. This + // ensures that the streams remain live for the duration of a run, even if + // an error is encountered and the streams are replaced with new ones. device_context_ = new XlaDeviceContext( - stream_.get(), host_to_device_stream, device_to_host_stream, client(), - transfer_as_literal_, shape_representation_fn_); + stream_, host_to_device_stream, device_to_host_stream, client(), + transfer_as_literal_, shape_representation_fn_, thread_pool_.get()); VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext " << device_context_; @@ -371,6 +380,22 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, op_kernel->ComputeAsync(context, done); } +Status XlaDevice::Sync() { + VLOG(1) << "XlaDevice::Sync"; + std::shared_ptr stream; + { + mutex_lock lock(mu_); + stream = stream_; + } + if (!stream) return Status::OK(); + + if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) { + return errors::Internal("XlaDevice::Sync() failed."); + } + VLOG(1) << "XlaDevice::Sync completed"; + return Status::OK(); +} + Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index d8906419b0c..dbf35f349f8 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/allocator.h" @@ -124,7 +123,7 @@ class XlaDevice : public LocalDevice { void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; - Status Sync() override { return Status::OK(); } + Status Sync() override; Status FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) override @@ -153,7 +152,7 @@ class XlaDevice : public LocalDevice { Allocator* GetAllocatorLocked(AllocatorAttributes attr) EXCLUSIVE_LOCKS_REQUIRED(mu_); Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, - xla::StreamPool::Ptr* stream, + std::shared_ptr* stream, bool* stream_was_changed) EXCLUSIVE_LOCKS_REQUIRED(mu_); xla::StatusOr GetDeviceContextLocked() @@ -174,17 +173,17 @@ class XlaDevice : public LocalDevice { // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and // computations enqueued by XLA. - xla::StreamPool::Ptr stream_ GUARDED_BY(mu_); + std::shared_ptr stream_ GUARDED_BY(mu_); // If false, only stream_ is valid and all computation and transfers use // stream_. If true, computation is performed by stream_ and transfers are // performed by host_to_device/device_to_host_stream. const bool use_multiple_streams_; // If use_multiple_streams_, host to device transfers are performed using this // stream. - xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_); + std::shared_ptr host_to_device_stream_ GUARDED_BY(mu_); // If use_multiple_streams_, device to host transfers are performed using this // stream. 
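
The hunks above replace the device's xla::StreamPool::Ptr stream members with std::shared_ptr<se::Stream>, so that an XlaDeviceContext (and the new XlaDevice::Sync) can keep a borrowed stream alive even after the device swaps it for a fresh one. A minimal standalone sketch of that ownership hand-off, using a stand-in Stream type and a fake pool deleter rather than the real se::Stream / xla::StreamPool:

#include <iostream>
#include <memory>

struct Stream {};  // stand-in for se::Stream

// Stand-in for xla::StreamPool::Ptr: a unique_ptr whose deleter returns the
// stream to its pool (here the "pool" is just a log line).
using PoolPtr = std::unique_ptr<Stream, void (*)(Stream*)>;

PoolPtr BorrowStream() {
  return PoolPtr(new Stream, [](Stream* s) {
    std::cout << "stream returned to pool\n";
    delete s;
  });
}

int main() {
  // Converting the pool-owned pointer into a shared_ptr keeps the custom
  // deleter, so the stream still goes back to the pool when the last owner
  // releases it -- the same effect as constructing a shared_ptr from the
  // borrowed StreamPool::Ptr in EnsureStreamOkLocked above.
  std::shared_ptr<Stream> device_stream(BorrowStream());

  // A device context takes an extra reference, as XlaDeviceContext now does.
  std::shared_ptr<Stream> context_stream = device_stream;

  // The device may now drop or replace its stream; the context's copy stays
  // valid, and the pool deleter runs only after both owners are gone.
  device_stream.reset();
  std::cout << "device released its stream; context still holds one reference\n";
  return 0;
}
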
- xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_); + std::shared_ptr device_to_host_stream_ GUARDED_BY(mu_); // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. const bool transfer_as_literal_; @@ -198,6 +197,9 @@ class XlaDevice : public LocalDevice { // Holds extra information for GPU and TPU devices, e.g. the device context. bool use_gpu_device_info_ GUARDED_BY(mu_) = false; std::unique_ptr gpu_device_info_ GUARDED_BY(mu_); + + // Thread pool used for running closures + std::unique_ptr thread_pool_; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 0100bf51ed2..0a0c0892411 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_context.h" +#include + +#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -48,17 +51,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } XlaTransferManager::XlaTransferManager( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : stream_(compute_stream), - host_to_device_stream_(host_to_device_stream), - device_to_host_stream_(device_to_host_stream), + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool) + : stream_(std::move(compute_stream)), + host_to_device_stream_(std::move(host_to_device_stream)), + device_to_host_stream_(std::move(device_to_host_stream)), client_(client), transfer_manager_(client->backend().transfer_manager()), transfer_as_literal_(transfer_as_literal), - shape_representation_fn_(std::move(shape_representation_fn)) { + shape_representation_fn_(std::move(shape_representation_fn)), + thread_pool_(thread_pool) { CHECK(host_to_device_stream_ != nullptr); CHECK(device_to_host_stream_ != nullptr); CHECK(stream_ != nullptr); @@ -88,15 +94,15 @@ Status XlaTransferManager::TransferLiteralToDevice( if (UseMultipleStreams()) { // Initially wait for the compute stream so that memory allocations are // synchronized. 
- host_to_device_stream_->ThenWaitFor(stream_); + host_to_device_stream_->ThenWaitFor(stream_.get()); } TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( - host_to_device_stream_, *literal, shaped_buffer)); + host_to_device_stream_.get(), *literal, shaped_buffer)); if (UseMultipleStreams()) { - se::Event event(stream_->parent()); - TF_RET_CHECK(event.Init()) << "Event failed to initialize!"; - host_to_device_stream_->ThenRecordEvent(&event); - xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event)); + auto event = std::make_shared(stream_->parent()); + TF_RET_CHECK(event->Init()) << "Event failed to initialize!"; + host_to_device_stream_->ThenRecordEvent(event.get()); + xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event)); } // Unref the host tensor, and capture the literal shared_ptr too so it goes // out of scope when the lambda completes. @@ -116,7 +122,7 @@ void XlaTransferManager::TransferLiteralFromDevice( TensorReference ref(device_tensor); transfer_manager_->TransferLiteralFromDevice( - device_to_host_stream_, shaped_buffer, literal, + device_to_host_stream_.get(), shaped_buffer, literal, [=, &shaped_buffer, &literal](xla::Status status) { ref.Unref(); done([&]() -> Status { @@ -179,8 +185,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); if (status.ok()) { xla_tensor->set_host_tensor(*cpu_tensor); - host_to_device_stream_->ThenDoHostCallback( - [done]() { done(Status::OK()); }); + host_to_device_stream_->ThenDoHostCallback([this, done]() { + // We must not call the done closure directly from DoHostCallback + // to avoid a deadlock. If done() is the callback that ends an + // Executor's run, the Executor may call XlaDevice::Sync() inside the + // callback. This deadlocks, because XlaDevice::Sync() waits for all + // stream activity to complete. + thread_pool_->Schedule([done]() { done(Status::OK()); }); + }); return; } } else { @@ -192,7 +204,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, if (!block_status.ok()) { status = xla::InternalError( "Failed to complete data transfer on stream %p: %s", - host_to_device_stream_, block_status.error_message().c_str()); + host_to_device_stream_.get(), block_status.error_message().c_str()); } } xla_tensor->set_host_tensor(*cpu_tensor); @@ -225,9 +237,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); if (se::Event* event = - xla_tensor->GetDefinitionEvent(device_to_host_stream_)) { + xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) { device_to_host_stream_->ThenWaitFor(event); - xla_tensor->SetDefinedOn(device_to_host_stream_); + xla_tensor->SetDefinedOn(device_to_host_stream_.get()); } Status status; @@ -240,7 +252,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, Status block_status = device_to_host_stream_->BlockHostUntilDone(); if (!block_status.ok()) { status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, + "Failed to complete data transfer on stream %p: %s", stream_.get(), block_status.error_message().c_str()); } } @@ -278,14 +290,14 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, if (stream_ != device_to_device_stream) { // Initially wait for the compute stream so that memory allocations are // synchronized. 
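
The CopyCPUTensorToDevice hunk above stops invoking the done closure inside ThenDoHostCallback and instead hands it to the device's ThreadPool (CopyDeviceTensorToDevice gets the same treatment below), because the closure may end an Executor run and re-enter the new XlaDevice::Sync(), which waits on the very stream that is still executing the host callback. A toy, self-contained illustration of why deferring the closure to another thread breaks that wait cycle; FakeStream is a stand-in, not the real se::Stream:

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>

// Runs one host callback on its own worker thread; Sync() blocks until that
// callback has finished, loosely mimicking a stream draining its work.
class FakeStream {
 public:
  void ThenDoHostCallback(std::function<void()> cb) {
    worker_ = std::thread([this, cb] {
      cb();  // Sync() cannot return until this callback completes.
      std::lock_guard<std::mutex> lock(mu_);
      drained_ = true;
      cv_.notify_all();
    });
  }

  void Sync() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return drained_; });
  }

  ~FakeStream() {
    if (worker_.joinable()) worker_.join();
  }

 private:
  std::thread worker_;
  std::mutex mu_;
  std::condition_variable cv_;
  bool drained_ = false;
};

int main() {
  FakeStream stream;
  std::thread pool_thread;  // plays the role of thread_pool_->Schedule(...)

  // If the user closure called stream.Sync() directly from inside the host
  // callback, Sync() would wait for the callback and the callback would wait
  // for Sync(): a deadlock. Handing the closure to another thread means the
  // host callback returns immediately and the cycle never forms.
  stream.ThenDoHostCallback([&] {
    pool_thread = std::thread([&stream] {
      stream.Sync();
      std::cout << "done-closure ran after the stream drained\n";
    });
  });

  stream.Sync();
  pool_thread.join();
  return 0;
}
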
- device_to_device_stream->ThenWaitFor(stream_); + device_to_device_stream->ThenWaitFor(stream_.get()); } } if (se::Event* event = - xla_src->GetDefinitionEvent(device_to_device_stream)) { + xla_src->GetDefinitionEvent(device_to_device_stream.get())) { device_to_device_stream->ThenWaitFor(event); - xla_src->SetDefinedOn(device_to_device_stream); + xla_src->SetDefinedOn(device_to_device_stream.get()); } auto from_iter = xla_src->shaped_buffer().buffers().begin(); @@ -297,28 +309,37 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, } if (UseMultipleStreams()) { - se::Event event(stream_->parent()); - CHECK(event.Init()); - device_to_device_stream->ThenRecordEvent(&event); - xla_dst->SetDefinedOn(device_to_device_stream, std::move(event)); + auto event = std::make_shared(stream_->parent()); + TF_RET_CHECK(event->Init()) << "Event failed to initialize"; + device_to_device_stream->ThenRecordEvent(event.get()); + xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event)); } return Status::OK(); }(); if (!status.ok()) { return done(status); } else { - stream_->ThenDoHostCallback([=]() { done(Status::OK()); }); + stream_->ThenDoHostCallback([this, done]() { + // We must not call the done closure directly from DoHostCallback to avoid + // a deadlock. If done() is the callback that ends an Executor's run, the + // Executor may call XlaDevice::Sync() inside the callback. This + // deadlocks, because XlaDevice::Sync() waits for all stream activity to + // complete. + thread_pool_->Schedule([done]() { done(Status::OK()); }); + }); } } XlaDeviceContext::XlaDeviceContext( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : manager_(compute_stream, host_to_device_stream, device_to_host_stream, - client, transfer_as_literal, - std::move(shape_representation_fn)) {} + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool) + : manager_(std::move(compute_stream), std::move(host_to_device_stream), + std::move(device_to_host_stream), client, transfer_as_literal, + std::move(shape_representation_fn), thread_pool) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 912f8d779e7..2e7445340cb 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -47,10 +47,12 @@ class XlaDeviceAllocator : public Allocator { class XlaTransferManager { public: explicit XlaTransferManager( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, + xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; @@ -61,7 +63,7 @@ class XlaTransferManager { void 
CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); - se::Stream* stream() const { return stream_; } + se::Stream* stream() const { return stream_.get(); } private: Status TransferLiteralToDevice(const Tensor& host_tensor, @@ -73,13 +75,13 @@ class XlaTransferManager { // The main compute stream of the device, used to synchronize the transfer // streams if they are set. - se::Stream* stream_; + std::shared_ptr stream_; // The stream to use for transferring data from host to device. Can be // idential to stream_, but must not be nullptr. - se::Stream* host_to_device_stream_; + std::shared_ptr host_to_device_stream_; // The stream to use for transferring data from device to host. Can be // idential to stream_, but must not be nullptr. - se::Stream* device_to_host_stream_; + std::shared_ptr device_to_host_stream_; // For the underlying memory allocator and XLA's TransferManager. xla::LocalClient* client_; // Transfer manager, for marshalling data to and from the device. @@ -87,6 +89,9 @@ class XlaTransferManager { // True if we must use XLA's TransferManager for correct device transfers. const bool transfer_as_literal_; XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // Thread pool used for running closures + thread::ThreadPool* thread_pool_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -95,10 +100,12 @@ class XlaTransferManager { class XlaDeviceContext : public DeviceContext { public: explicit XlaDeviceContext( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, + xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 6134b8c6946..4efbb2d5d7c 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_launch_util.h" +#include + #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -182,7 +184,7 @@ void XlaComputationLaunchContext::PopulateInputs( } } -void XlaComputationLaunchContext::PopulateOutputs( +Status XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, ScopedShapedBuffer output) { se::Stream* stream = @@ -211,6 +213,15 @@ void XlaComputationLaunchContext::PopulateOutputs( output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator()); } + std::shared_ptr definition_event; + if (use_multiple_streams_) { + definition_event = std::make_shared(stream->parent()); + if (!definition_event->Init()) { + return errors::Internal("Failed to initialize tensor definition event."); + } + stream->ThenRecordEvent(definition_event.get()); + } + // Copy XLA results to the OpOutputList. 
int output_num = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { @@ -228,12 +239,13 @@ void XlaComputationLaunchContext::PopulateOutputs( // reallocate the device buffer later. VLOG(1) << "Constant output tensor on device"; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); + TF_RETURN_IF_ERROR( + ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); Device* device = dynamic_cast(ctx->device()); - OP_REQUIRES(ctx, device != nullptr, - errors::Internal("DeviceBase was not a Device.")); + if (device == nullptr) { + return errors::Internal("DeviceBase was not a Device."); + } ctx->op_device_context()->CopyCPUTensorToDevice( &const_tensor, device, output_tensor, [&](Status status) { TF_CHECK_OK(status); }); @@ -263,16 +275,13 @@ void XlaComputationLaunchContext::PopulateOutputs( se::DeviceMemoryBase buffer = output.buffer({output_num}); if (allocate_xla_tensors_) { Tensor* output_tensor; - OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor)); + TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); if (xla_tensor) { xla_tensor->set_shaped_buffer(ScopedShapedBuffer( ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); if (use_multiple_streams_) { - se::Event event(stream->parent()); - CHECK(event.Init()); - stream->ThenRecordEvent(&event); - xla_tensor->SetDefinedOn(stream, std::move(event)); + xla_tensor->SetDefinedOn(stream, definition_event); } } else { // xla_tensor wasn't valid, which must mean this is a zero-element @@ -298,41 +307,39 @@ void XlaComputationLaunchContext::PopulateOutputs( for (int i = 0; i < kernel->resource_updates.size(); ++i) { Allocator* allocator = ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - OP_REQUIRES(ctx, - write.input_index >= 0 && write.input_index < ctx->num_inputs(), - errors::Internal("Invalid input index for variable write.")); + if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) { + return errors::Internal("Invalid input index for variable write."); + } se::DeviceMemoryBase buffer = output.buffer({output_num}); Var* variable = nullptr; // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, // not a Tensor. 
- OP_REQUIRES_OK(ctx, LookupOrCreateResource( - ctx, HandleFromInput(ctx, write.input_index), - &variable, [this, ctx, &write](Var** ptr) { - *ptr = new Var(write.type); - return Status::OK(); - })); + TF_RETURN_IF_ERROR(LookupOrCreateResource( + ctx, HandleFromInput(ctx, write.input_index), &variable, + [&write](Var** ptr) { + *ptr = new Var(write.type); + return Status::OK(); + })); core::ScopedUnref s(variable); mutex_lock ml(*variable->mu()); - OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, - errors::Internal("Mismatched type in variable write")); + if (variable->tensor()->dtype() != write.type) { + return errors::Internal("Mismatched type in variable write"); + } if (allocate_xla_tensors_) { Tensor output_tensor; - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor)); + TF_RETURN_IF_ERROR( + ctx->allocate_temp(write.type, write.shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); CHECK(xla_tensor); xla_tensor->set_shaped_buffer( ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); if (use_multiple_streams_) { - se::Event event(stream->parent()); - CHECK(event.Init()); - stream->ThenRecordEvent(&event); - xla_tensor->SetDefinedOn(stream, std::move(event)); + xla_tensor->SetDefinedOn(stream, definition_event); } *variable->tensor() = output_tensor; } else { @@ -343,6 +350,7 @@ void XlaComputationLaunchContext::PopulateOutputs( } ++output_num; } + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 1ea3fa4cf29..4232f514b3b 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -93,9 +93,9 @@ class XlaComputationLaunchContext { const std::map& variables); // Given the XLA output in `output`, populate all outputs of `ctx`. - void PopulateOutputs(OpKernelContext* ctx, - const XlaCompiler::CompilationResult* kernel, - xla::ScopedShapedBuffer output); + Status PopulateOutputs(OpKernelContext* ctx, + const XlaCompiler::CompilationResult* kernel, + xla::ScopedShapedBuffer output); // Return the argument list. Only valid after PopulateInputs() has been // called. diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index d777dfa5a34..92ba7de1b7d 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -75,7 +75,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { mutex_lock lock(mu_); - if (!definition_event_.has_value()) { + if (!definition_event_) { return nullptr; } @@ -87,10 +87,11 @@ se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { return nullptr; } - return &*definition_event_; + return definition_event_.get(); } -void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) { +void XlaTensor::SetDefinedOn(se::Stream* stream, + std::shared_ptr event) { mutex_lock lock(mu_); definition_event_ = std::move(event); streams_defined_on_ = {stream}; diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index f7e401c7311..8d36d0fa0a8 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -16,6 +16,8 @@ limitations under the License. 
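
Since PopulateOutputs now returns a Status (see the header change below in xla_launch_util.h, and the TF_RETURN_IF_ERROR wrapper added earlier in xla_compile_on_demand_op.cc), the OP_REQUIRES / OP_REQUIRES_OK macros, which only make sense in void kernel methods that report errors through the OpKernelContext, are systematically rewritten as early returns and TF_RETURN_IF_ERROR. A tiny self-contained sketch of that control-flow shape, using stand-ins for tensorflow::Status and the macro rather than the real ones:

#include <iostream>
#include <string>
#include <utility>

// Minimal stand-ins, only to show the conversion pattern.
struct Status {
  bool ok = true;
  std::string message;
  static Status OK() { return {}; }
  static Status Internal(std::string m) { return {false, std::move(m)}; }
};

#define RETURN_IF_ERROR(expr)            \
  do {                                   \
    Status _status = (expr);             \
    if (!_status.ok) return _status;     \
  } while (0)

Status AllocateOutput(bool succeed) {
  return succeed ? Status::OK() : Status::Internal("allocation failed");
}

// Status-returning helper: each failed precondition becomes an early return,
// and the caller decides how to propagate the error.
Status PopulateOutputsSketch(bool device_ok, bool alloc_ok) {
  if (!device_ok) {
    return Status::Internal("DeviceBase was not a Device.");
  }
  RETURN_IF_ERROR(AllocateOutput(alloc_ok));
  return Status::OK();
}

int main() {
  std::cout << PopulateOutputsSketch(/*device_ok=*/true, /*alloc_ok=*/false).message
            << "\n";                                       // "allocation failed"
  std::cout << PopulateOutputsSketch(true, true).ok << "\n";  // 1
  return 0;
}
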
#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ #define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ +#include + #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/core/framework/allocator.h" @@ -94,7 +96,7 @@ class XlaTensor { // Assert that the tensor's content is defined on 'stream' by the time 'event' // triggers. - void SetDefinedOn(se::Stream* stream, se::Event event); + void SetDefinedOn(se::Stream* stream, std::shared_ptr event); // Assert that the tensor's content is defined on 'stream'. This version does // not provide an event, and must be called *after* SetDefinedOn(Stream, @@ -116,7 +118,7 @@ class XlaTensor { // An optional event that is triggered when the tensor's content has been // defined. If this event is nullptr, it is assumed that the tensor's content // is always defined. - gtl::optional definition_event_; + std::shared_ptr definition_event_; // A list of all streams for which the tensor's content is defined for any // newly enqueued command. gtl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index f42fb92359f..1bf8948ef6d 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -31,7 +31,6 @@ std::vector* flag_objects; std::once_flag flags_init; void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_enable_fast_math(true); flags->set_xla_llvm_enable_alias_scope_metadata(true); flags->set_xla_llvm_enable_noalias_metadata(true); flags->set_xla_llvm_enable_invariant_load_metadata(true); @@ -53,6 +52,11 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // the heuristics needed to decide when to run on multiple streams. See // b/77879207. flags->set_xla_gpu_disable_multi_streaming(true); + + // TODO(jlebar): Disable fastmath once doing so is not a performance + // regression. 
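
With the xla_launch_util.cc and xla_tensor.h/.cc hunks above, PopulateOutputs records a single definition event on the stream and every output XlaTensor holds a std::shared_ptr<se::Event> to it, instead of each tensor owning its own freshly recorded se::Event. A small stand-alone sketch of that sharing; Event here is a stand-in, not se::Event:

#include <iostream>
#include <memory>
#include <vector>

// Stand-in for se::Event; prints when it is created and destroyed so the
// move-per-tensor to shared-per-launch change is visible.
struct Event {
  Event() { std::cout << "definition event created\n"; }
  ~Event() { std::cout << "definition event destroyed\n"; }
};

struct FakeXlaTensor {
  // Mirrors XlaTensor::definition_event_ after the change: shared, optional.
  std::shared_ptr<Event> definition_event;
};

int main() {
  // Record one event for the whole launch...
  auto definition_event = std::make_shared<Event>();

  // ...and attach it to every output tensor, as SetDefinedOn(stream, event)
  // now does with a shared_ptr instead of a moved se::Event.
  std::vector<FakeXlaTensor> outputs(3);
  for (FakeXlaTensor& t : outputs) {
    t.definition_event = definition_event;
  }

  definition_event.reset();  // the launch context lets go of its reference
  std::cout << "event still alive via " << outputs.size() << " tensors\n";
  return 0;
}  // the event is destroyed exactly once, after the last tensor releases it
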
+ flags->set_xla_cpu_enable_fast_math(true); + flags->set_xla_gpu_enable_fast_math(true); } // Allocates flag_values and flag_objects; this function must not be called more @@ -150,10 +154,16 @@ void AllocateFlags() { flag_values->mutable_xla_generate_hlo_text_to(), "Dump all HLO modules as text into the provided directory path."), tensorflow::Flag( - "xla_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_enable_fast_math), - flag_values->xla_enable_fast_math(), - "Enable unsafe fast-math optimizations in the compiler; " + "xla_cpu_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), + flag_values->xla_cpu_enable_fast_math(), + "Enable unsafe fast-math optimizations in the CPU compiler; " + "this may produce faster code at the expense of some accuracy."), + tensorflow::Flag( + "xla_gpu_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), + flag_values->xla_cpu_enable_fast_math(), + "Enable unsafe fast-math optimizations in the GPU compiler; " "this may produce faster code at the expense of some accuracy."), tensorflow::Flag( "xla_llvm_enable_alias_scope_metadata", diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 7d315fa0d3d..7331d2b54cf 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1234,6 +1234,20 @@ cc_library( ], ) +cc_library( + name = "scatter_expander", + srcs = ["scatter_expander.cc"], + hdrs = ["scatter_expander.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + ":while_util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + ], +) + tf_cc_test( name = "batchnorm_expander_test", size = "small", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 37834e1cc26..f7812d96614 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1705,6 +1705,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { reshape, HloInstruction::CreateReshape(reshape->shape(), operand->mutable_operand(0))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { + *operand->mutable_shape() = reshape->shape(); + return ReplaceInstruction(reshape, operand); + } if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) { auto opt_dims = ReshapeLeavesDimensionsUnmodified( @@ -2144,6 +2148,11 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { transpose->dimensions()))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { + *operand->mutable_shape() = transpose->shape(); + return ReplaceInstruction(transpose, operand); + } + if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) { ReplaceWithBitcast(transpose); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 862cbeeba6b..5837391d759 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1428,6 +1428,37 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } +// Test transforming reshapes and transposes of rng. 
+TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { + HloComputation::Builder builder(TestName()); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction* rng0 = builder.AddInstruction( + HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {2, 2}), + RandomDistribution::RNG_UNIFORM, {zero, one})); + + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(rng0->shape(), rng0, {1, 0})); + Shape reshape_shape = builder + .AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {4}), transpose)) + ->shape(); + + auto computation = module().AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + // Verify that that reshape(transpose(rng)) is replace by a single rng of the + // same shape as the reshape. + EXPECT_THAT(computation->root_instruction(), op::Rng()); + EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(), + reshape_shape)); +} + // Test transforming reshapes to bitcasts under various conditions. TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { HloComputation::Builder builder(TestName()); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 118a11c8de3..cfd26fc778c 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -139,6 +139,7 @@ Status GatherComputationsByAllocationType( case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kFusion: // Map/reduce etc computations are always thread-local. diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index a23427f00cc..985ff30e80a 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -61,6 +61,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kFusion: return CallContext::kParallel; diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 36fb9b43aa2..3e39c1bab1e 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -312,7 +312,7 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, return Status::OK(); } -// We add copies for all the indices of the true and false computaiton roots, +// We add copies for all the indices of the true and false computation roots, // in order to resolve interference. We later rely on the CopyRemover to drop // the unnecessary ones. Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, @@ -648,7 +648,12 @@ class CopyRemover { // We can only perform copy elision if the resulting merged values have // totally ordered live ranges; otherwise the merged buffer would have // live range interference. 
- if (IsHead(*dest)) { + if (src->next == dest) { + // In the process of eliding copies, its possible for a copy to have the + // same source and destination buffer. In this case, the copy can be + // safely removed. + VLOG(2) << copy->name() << " source and destination buffers are same."; + } else if (IsHead(*dest)) { // The copy copies an arbitrary value in the source buffer (call it s_x) // and defines d_0, the first value in the destination buffer. After // merging, the values in the combined buffer must be strictly ordered diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index cd735256b83..892d0d7b547 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -2007,5 +2007,46 @@ ENTRY TestComputation { InsertCopies(module.get()); } +TEST_F(CopyInsertionTest, NestedWhiles) { + // Verify that only no unnecessary copies remain after copy insertion for + // trivial nested whiles (b/112472605). + const string& hlo_string = R"( +HloModule TestModule + +cond.inner { + ROOT param.cond.inner = pred[] parameter(0) +} + +body.inner { + param.body.inner = pred[] parameter(0) + ROOT neg = pred[] negate(param.body.inner) +} + +cond.outer { + ROOT param.cond.outer = pred[] parameter(0) +} + +body.outer { + param.cond.outer = pred[] parameter(0) + ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner +} + +ENTRY TestComputation { + entry_param = pred[] parameter(0) + ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + InsertCopies(module.get()); + + // There should only be a single copy inserted, and it's in the entry + // computation. + EXPECT_EQ(CountCopies(*module), 1); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::While(op::Copy(op::Parameter()))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 3efe3e2f93a..84779c60b0c 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -20,7 +20,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") load( "//third_party/mkl:build_defs.bzl", - "if_mkl", + "mkl_deps", ) # Filegroup used to collect source files for dependency checking. 
@@ -86,6 +86,7 @@ cc_library( ":parallel_task_assignment", ":simple_orc_jit", "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", @@ -497,10 +498,7 @@ cc_library( "//tensorflow/core:framework_lite", "//tensorflow/core/kernels:eigen_helpers", "//third_party/eigen3", - ] + if_mkl([ - "@mkl_dnn", - "//third_party/mkl:intel_binary_blob", - ]), + ] + mkl_deps(), ) cc_library( @@ -554,10 +552,7 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", "//third_party/eigen3", - ] + if_mkl([ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ]), + ] + mkl_deps(), ) cc_library( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 2df959c4dc5..35154af0482 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -88,6 +88,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -299,6 +300,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); + pipeline.AddPass(); + ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -356,7 +359,7 @@ llvm::TargetOptions CompilerTargetOptions( llvm::TargetOptions target_options; llvm_ir::SetTargetOptions( /*fast_math_enabled=*/module_config.debug_options() - .xla_enable_fast_math(), + .xla_cpu_enable_fast_math(), &target_options); return target_options; } @@ -523,7 +526,7 @@ StatusOr> CpuCompiler::RunBackend( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_enable_fast_math(), + module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_hook, post_optimization_ir_hook); llvm_module->setDataLayout(jit->data_layout()); @@ -653,9 +656,9 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, // so we bail if the configs have conflicting flags. At the moment, the only // flag that needs to be consistent is fast-math. 
const bool fast_math_enabled = - modules[0]->config().debug_options().xla_enable_fast_math(); + modules[0]->config().debug_options().xla_cpu_enable_fast_math(); for (const auto& module : modules) { - if (module->config().debug_options().xla_enable_fast_math() != + if (module->config().debug_options().xla_cpu_enable_fast_math() != fast_math_enabled) { return InvalidArgument( "All HLO module configs must have the same value for " @@ -832,7 +835,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, CompilerFunctor compiler_functor( target_machine.get(), &disassembler, opt_level, options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_enable_fast_math(), + module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); std::unique_ptr object_file = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 946f5124b87..c376864c3e1 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -249,24 +249,11 @@ StatusOr CpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { - if (GetRootPointsToSet().IsAmbiguous()) { - return Unimplemented("Points-to set of root instruction is ambiguous"); - } - - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - - std::vector owning_buffers; - std::vector unowning_buffers; TF_ASSIGN_OR_RETURN( - std::tie(unowning_buffers, owning_buffers), - CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), - arguments)); - - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), unowning_buffers, hlo_execution_profile)); - - return CreateResultShapedBuffer(run_options, &owning_buffers); + auto result, + ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile)); + TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone()); + return std::move(result); } StatusOr CpuExecutable::ExecuteAsyncOnStream( @@ -277,6 +264,16 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( "Asynchronous execution on stream with hlo profiling is not yet " "supported on CPU."); } + return ExecuteAsyncOnStreamImpl(run_options, arguments, nullptr); +} + +StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + if (GetRootPointsToSet().IsAmbiguous()) { + return Unimplemented("Points-to set of root instruction is ambiguous"); + } auto* host_stream = dynamic_cast( run_options->stream()->implementation()); @@ -310,19 +307,20 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( ServiceExecutableRunOptions run_options; std::vector unowning_buffers; std::shared_ptr> buffers; + HloExecutionProfile* hlo_execution_profile; void operator()() { // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. 
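
The cpu_executable.cc hunks above turn ExecuteOnStream into a thin wrapper: it calls the shared ExecuteAsyncOnStreamImpl and then simply blocks on the stream, so the synchronous and asynchronous paths run the same code (and the async path can now carry an HloExecutionProfile, see the AsyncRunTask change just below). A stripped-down sketch of that sync-over-async shape, with std::future standing in for the host stream and BlockHostUntilDone:

#include <future>
#include <iostream>

// Shared implementation: enqueue the work and return a handle immediately.
std::future<int> ExecuteAsyncImpl(int x) {
  return std::async(std::launch::async, [x] { return 2 * x; });
}

// Async entry point: just forwards to the shared implementation.
std::future<int> ExecuteAsync(int x) { return ExecuteAsyncImpl(x); }

// Sync entry point: run the async implementation, then block -- the analogue
// of ExecuteOnStream calling ExecuteAsyncOnStreamImpl followed by
// run_options->stream()->BlockHostUntilDone().
int ExecuteSync(int x) {
  std::future<int> result = ExecuteAsyncImpl(x);
  return result.get();
}

int main() {
  std::cout << ExecuteSync(21) << "\n";  // 42, after blocking
  auto pending = ExecuteAsync(7);
  std::cout << pending.get() << "\n";    // 14, the caller blocks when it chooses
  return 0;
}
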
TF_CHECK_OK(executable->ExecuteComputeFunction( - &run_options.run_options(), unowning_buffers, - /*hlo_execution_profile=*/nullptr)); + &run_options.run_options(), unowning_buffers, hlo_execution_profile)); } }; host_stream->EnqueueTask( AsyncRunTask{this, *run_options, std::move(unowning_buffers), std::make_shared>( - std::move(owning_buffers))}); + std::move(owning_buffers)), + hlo_execution_profile}); return std::move(result); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 8af8a5dfec2..96e53de57ee 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -85,6 +85,16 @@ class CpuExecutable : public Executable { const BufferAssignment& buffer_assignment() const { return *assignment_; } private: + // This is for sharing the code between ExecuteOnStream and + // ExecuteAsyncOnStream. + // + // Notice that it's tricky to use correctly, as the profile object (when it + // exists) must out-live the task. + StatusOr ExecuteAsyncOnStreamImpl( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile); + // Creates an array suitable for passing as the "temps" argument to the JIT // compiled function pointer. // diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 645888de783..f2ac742b6e6 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -1066,7 +1066,7 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( << config.GetCacheKey(); const bool enable_fast_math = - hlo_module_config_.debug_options().xla_enable_fast_math(); + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); @@ -1149,7 +1149,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); const bool enable_fast_math = - hlo_module_config_.debug_options().xla_enable_fast_math(); + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 09909b62ba4..6f433b4f303 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -99,7 +99,7 @@ IrEmitter::IrEmitter( target_machine_features_(*target_machine_features) { b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() - .xla_enable_fast_math())); + .xla_cpu_enable_fast_math())); } StatusOr IrEmitter::EmitComputation( @@ -158,11 +158,11 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; // Create and initialize new IrFunction. 
- compute_function_.reset( - new IrFunction(function_name, linkage, - options::OptimizeForSizeRequested(hlo_module_config_), - hlo_module_config_.debug_options().xla_enable_fast_math(), - module_, &b_, num_dynamic_loop_bounds_)); + compute_function_.reset(new IrFunction( + function_name, linkage, + options::OptimizeForSizeRequested(hlo_module_config_), + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(), module_, + &b_, num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -577,7 +577,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{reduce_window->operand(0)}, - /*supported_types=*/{F32, BF16, S32})); + /*supported_types=*/{F32, BF16, S32, F16})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(reduce_window->window())) { diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index d938f3a2c4b..48e44714998 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -21,8 +21,33 @@ limitations under the License. namespace xla { +namespace { + +// Pass which strips control dependencies from all instructions in the module. +class ControlDepRemover : public HloPassInterface { + public: + ControlDepRemover() = default; + tensorflow::StringPiece name() const override { + return "control-dep-remover"; + } + + StatusOr Run(HloModule* module) override { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + changed = changed || !instruction->control_predecessors().empty(); + TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); + } + } + return changed; + } +}; + +} // namespace + Despecializer::Despecializer() : pipeline_("despecializer") { // TODO(b/70588125): Also deal with window reversal in a fast way. + pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a3f6e8d9893..19575c7905b 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,6 +1,7 @@ # Description: # GPU-specific components in XLA service implementation. 
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") licenses(["notice"]) # Apache 2.0 @@ -365,6 +366,7 @@ cc_library( ":gpu_executable", ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", @@ -652,6 +654,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", @@ -852,3 +855,35 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +cc_library( + name = "buffer_comparator", + srcs = ["buffer_comparator.cc"], + hdrs = ["buffer_comparator.h"], + deps = [ + ":gpu_executable", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +xla_test( + name = "buffer_comparator_test", + srcs = ["buffer_comparator_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":buffer_comparator", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc new file mode 100644 index 00000000000..6a285a6b989 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -0,0 +1,205 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" + +#include +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace gpu { + +static constexpr float kTolerance = 0.1f; + +static string GetCompHloText(size_t num_elements) { + // Implements the textual format of the comparison routine, as it's more + // readable. 
+ static constexpr char kF16CompHloText[] = R"( +HloModule CompareF16 + +MaxF32 { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %max = f32[] maximum(%lhs, %rhs) +} + +Canonicalize (aparam: f16[SIZE]) -> f32[SIZE] { + %min_constant = f32[] constant(-65505) + %max_constant = f32[] constant(65505) + %large_constant = f32[] constant(1048576) + %min_values = f32[SIZE] broadcast(%min_constant), dimensions={} + %max_values = f32[SIZE] broadcast(%max_constant), dimensions={} + %large_values = f32[SIZE] broadcast(%large_constant), dimensions={} + + %a = f16[SIZE] parameter(0) + %converted = f32[SIZE] convert(%a) + %clamped = f32[SIZE] clamp(%min_values, %converted, %max_values) + + // Since the clamp() above already took care of infs, only NaNs will cause + // is-finite() to return false. + %is_finite = pred[SIZE] is-finite(%clamped) + ROOT %result = f32[SIZE] select(%is_finite, %clamped, %large_values) +} + +ENTRY MaxDifference { + %one_constant = f32[] constant(1.0) + %zero_constant = f32[] constant(0.0) + + %ones = f32[SIZE] broadcast(%one_constant), dimensions={} + + %lhs = f16[SIZE] parameter(0) + %rhs = f16[SIZE] parameter(1) + %lhs_canonical = f32[SIZE] call(%lhs), to_apply=Canonicalize + %rhs_canonical = f32[SIZE] call(%rhs), to_apply=Canonicalize + %sub = f32[SIZE] subtract(%lhs_canonical, %rhs_canonical) + %sub_abs = f32[SIZE] abs(%sub) + %lhs_abs = f32[SIZE] abs(%lhs_canonical) + %rhs_abs = f32[SIZE] abs(%rhs_canonical) + %max = f32[SIZE] maximum(%lhs_abs, %rhs_abs) + %denominator = f32[SIZE] add(%max, %ones) + %error = f32[SIZE] divide(%sub_abs, %denominator) + ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32 +})"; + auto size_string = std::to_string(num_elements); + return tensorflow::str_util::StringReplace( + kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true); +} + +StatusOr F16BufferComparator::Create( + se::DeviceMemory ref_buffer, Compiler* compiler, + DeviceMemoryAllocator* allocator, se::Stream* stream) { + auto stream_exec = stream->parent(); + int64 num_elements = ref_buffer.ElementCount(); + + // One may consider using hlo_runner to do all the compilation and execution. + // However, as of the time hlo_runner doesn't support injection for Compiler*, + // Stream*, or even the allocator. We may revisit this in the future if it + // proves to be a maintenance burden. 
+ TF_ASSIGN_OR_RETURN( + auto exec, ([&]() -> StatusOr> { + HloModuleConfig config; + DebugOptions debug_options; + debug_options.set_xla_backend_optimization_level(2); + config.set_debug_options(debug_options); + TF_ASSIGN_OR_RETURN( + auto module, ParseHloString(GetCompHloText(num_elements), config)); + TF_ASSIGN_OR_RETURN( + module, + compiler->RunHloPasses(std::move(module), stream_exec, nullptr)); + return compiler->RunBackend(std::move(module), stream_exec, nullptr); + }())); + + TF_ASSIGN_OR_RETURN( + auto shaped_buffer, ([&]() -> StatusOr { + auto device_ordinal = stream_exec->device_ordinal(); + TF_ASSIGN_OR_RETURN( + auto owning_buffer, + allocator->Allocate(device_ordinal, ref_buffer.size())); + se::DeviceMemory buffer( + owning_buffer.AsDeviceMemoryBase()); + stream->ThenMemcpy(&buffer, ref_buffer, ref_buffer.size()); + Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements}); + ScopedShapedBuffer ret(shape, shape, allocator, device_ordinal); + ret.set_buffer(std::move(owning_buffer), {}); + return std::move(ret); + }())); + + return F16BufferComparator(stream, allocator, std::move(exec), + std::move(shaped_buffer)); +} + +StatusOr F16BufferComparator::CompareEqualImpl( + se::DeviceMemory test_buffer) { + if (ref_buffer_.root_buffer().size() != test_buffer.size()) { + return InternalError("Mismatched buffer size: %lld vs %lld", + ref_buffer_.root_buffer().size(), test_buffer.size()); + } + + int64 num_elements = test_buffer.ElementCount(); + + TF_ASSIGN_OR_RETURN( + auto result_buffer, ([&]() -> StatusOr { + auto stream_exec = stream_->parent(); + Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements}); + auto device_ordinal = stream_exec->device_ordinal(); + ShapedBuffer shaped_test_buffer(shape, shape, stream_exec->platform(), + device_ordinal); + shaped_test_buffer.set_buffer(test_buffer, {}); + ExecutableRunOptions run_options; + run_options.set_device_ordinal(stream_exec->device_ordinal()); + run_options.set_stream(stream_); + run_options.set_allocator(allocator_); + ServiceExecutableRunOptions service_run_options(run_options); + return exec_->ExecuteOnStream( + &service_run_options, {&ref_buffer_, &shaped_test_buffer}, nullptr); + }())); + + float result; + CHECK(result_buffer.root_buffer().size() == sizeof(result)); + stream_->ThenMemcpy(&result, result_buffer.root_buffer(), sizeof(result)); + TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + return result < kTolerance; +} + +StatusOr F16BufferComparator::CompareEqual( + se::DeviceMemory test_buffer) { + TF_ASSIGN_OR_RETURN(auto result, CompareEqualImpl(test_buffer)); + if (result) { + return true; + } + // Host side code that does the same thing, but report some of the + // differences as well. 
+ int64 n = test_buffer.ElementCount(); + std::vector host_ref_buffer(n), host_test_buffer(n); + stream_->ThenMemcpy(host_ref_buffer.data(), ref_buffer_.root_buffer(), + ref_buffer_.root_buffer().size()); + stream_->ThenMemcpy(host_test_buffer.data(), test_buffer, test_buffer.size()); + TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + + const auto canonicalize = [](float a) -> float { + constexpr float kBigNumer = 1048576.; + constexpr float kMaxFp16Value = 65504.; + if (std::isnan(a)) { + return kBigNumer; + } + if (std::isinf(a)) { + if (a < 0) { + return -(kMaxFp16Value + 1); + } + return kMaxFp16Value + 1; + } + return a; + }; + int differences_seen = 0; + for (int64 i = 0; i < n && differences_seen < 10; i++) { + float original_ref = static_cast(host_ref_buffer[i]); + float original_test = static_cast(host_test_buffer[i]); + float ref = canonicalize(original_ref); + float test = canonicalize(original_test); + if (!(std::abs(ref - test) / (std::max(std::abs(ref), std::abs(test)) + 1) < + kTolerance)) { + differences_seen++; + LOG(ERROR) << "Difference at " << i << ": " << original_ref << " vs " + << original_test; + } + } + + return false; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.h b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h new file mode 100644 index 00000000000..bf2ba78ceac --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ + +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A fp16 comparator that internally keeps a reference buffer, and compares it +// against other test buffers. +class F16BufferComparator { + public: + F16BufferComparator(const F16BufferComparator&) = delete; + F16BufferComparator(F16BufferComparator&&) = default; + + // Creates a new comparator. It internally allocates a buffer initialized by + // ref_buffer. + static StatusOr Create( + se::DeviceMemory ref_buffer, Compiler* compiler, + DeviceMemoryAllocator* allocator, se::Stream* stream); + + // Returns true if the internally allocated buffer "compares equal" to + // test_buffer. The definition of "equal" is: + // * All NaNs equal. + // * All infs are treated as 65505 or -65505, so that this checker is tolerant + // to fp16 overflows. + // * With NaNs and infs taken care of, a and b compare equal iff: + // abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance + // + // See the implementation for the tolerance value. 
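
The comparator's matching rule, spelled out in the header comment just above and implemented both in HLO text and in the host-side fallback earlier in buffer_comparator.cc, can be restated in a few lines of plain C++. This is only an illustrative host-side restatement; the constants follow the .cc above (tolerance 0.1, NaNs mapped to a large finite sentinel, infinities clamped just past the fp16 range):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Map NaN/inf to finite sentinels so they take part in the relative-error test.
float Canonicalize(float a) {
  constexpr float kNanSentinel = 1048576.f;  // any NaN compares equal to any NaN
  constexpr float kMaxFp16Value = 65504.f;   // largest finite fp16 magnitude
  if (std::isnan(a)) return kNanSentinel;
  if (std::isinf(a)) return std::copysign(kMaxFp16Value + 1, a);
  return a;
}

// Two values match when |a - b| / (max(|a|, |b|) + 1) < tolerance.
bool CompareEqual(float a, float b, float tolerance = 0.1f) {
  const float ca = Canonicalize(a);
  const float cb = Canonicalize(b);
  return std::abs(ca - cb) / (std::max(std::abs(ca), std::abs(cb)) + 1) <
         tolerance;
}

int main() {
  std::printf("%d\n", CompareEqual(20.f, 20.1f));  // 1: tiny relative error
  std::printf("%d\n", CompareEqual(0.f, 1.f));     // 0: relative error is 0.5
  std::printf("%d\n", CompareEqual(NAN, NAN));     // 1: NaNs treated as equal
  return 0;
}
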
+ StatusOr CompareEqual(se::DeviceMemory test_buffer); + + private: + F16BufferComparator(se::Stream* stream, DeviceMemoryAllocator* allocator, + std::unique_ptr exec, + ScopedShapedBuffer ref_buffer) + : stream_(stream), + allocator_(allocator), + exec_(std::move(exec)), + ref_buffer_(std::move(ref_buffer)) {} + + StatusOr CompareEqualImpl(se::DeviceMemory test_buffer); + + se::Stream* stream_; + DeviceMemoryAllocator* allocator_; + std::unique_ptr exec_; + ScopedShapedBuffer ref_buffer_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc new file mode 100644 index 00000000000..33761d1bd88 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" + +#include +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class BufferComparatorTest : public testing::Test { + protected: + BufferComparatorTest() + : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()), + stream_exec_(backend_->default_stream_executor()), + allocator_(stream_exec_->platform(), {stream_exec_}), + compiler_(Compiler::GetForPlatform(stream_exec_->platform()) + .ConsumeValueOrDie()) {} + + // Take floats only for convenience. Still uses half internally. 
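+  // For example, CompareEqualFloatBuffers({9}, {10}) is expected to be true
+  // under the comparator's relative tolerance, while
+  // CompareEqualFloatBuffers({0}, {1}) is expected to be false (see the
+  // TestNumbers case below).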
+ bool CompareEqualFloatBuffers(const std::vector& lhs_float, + const std::vector& rhs_float) { + std::vector lhs(lhs_float.begin(), lhs_float.end()); + std::vector rhs(rhs_float.begin(), rhs_float.end()); + se::Stream stream(stream_exec_); + stream.Init(); + + auto owning_lhs_buffer = + allocator_ + .Allocate(stream_exec_->device_ordinal(), lhs.size() * sizeof(half)) + .ConsumeValueOrDie(); + + auto owning_rhs_buffer = + allocator_ + .Allocate(stream_exec_->device_ordinal(), rhs.size() * sizeof(half)) + .ConsumeValueOrDie(); + + auto lhs_buffer = + se::DeviceMemory(owning_lhs_buffer.AsDeviceMemoryBase()); + auto rhs_buffer = + se::DeviceMemory(owning_rhs_buffer.AsDeviceMemoryBase()); + + stream.ThenMemcpy(&lhs_buffer, lhs.data(), lhs_buffer.size()); + stream.ThenMemcpy(&rhs_buffer, rhs.data(), rhs_buffer.size()); + + TF_CHECK_OK(stream.BlockHostUntilDone()); + + return F16BufferComparator::Create(lhs_buffer, compiler_, &allocator_, + &stream) + .ConsumeValueOrDie() + .CompareEqual(rhs_buffer) + .ConsumeValueOrDie(); + } + + std::unique_ptr backend_; + se::StreamExecutor* stream_exec_; + StreamExecutorMemoryAllocator allocator_; + Compiler* compiler_; +}; + +TEST_F(BufferComparatorTest, TestNaNs) { + EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")})); + // NaN values with different bit patterns should compare equal. + EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("1234")})); + EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.})); +} + +TEST_F(BufferComparatorTest, TestInfs) { + const auto inf = std::numeric_limits::infinity(); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504})); + EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504})); + + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); +} + +TEST_F(BufferComparatorTest, TestNumbers) { + EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1})); + EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9})); +} + +TEST_F(BufferComparatorTest, TestMultiple) { + EXPECT_TRUE(CompareEqualFloatBuffers({20, 30, 40, 50, 60}, + {20.1, 30.1, 40.1, 50.1, 60.1})); + std::vector lhs(200); + std::vector rhs(200); + for (int i = 0; i < 200; i++) { + EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the same at index " << i; + lhs[i] = 3; + rhs[i] = 5; + EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the different at index " << i; + lhs[i] = 0; + rhs[i] = 0; + } +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 7348307ec8a..7d93bdfc8b1 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -30,7 +30,6 @@ namespace { using se::DeviceMemoryBase; using se::dnn::AlgorithmConfig; using se::dnn::AlgorithmDesc; -using tensorflow::gtl::nullopt; 
using tensorflow::gtl::optional; class ScratchAllocator : public se::ScratchAllocator { @@ -173,7 +172,7 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. -optional> +StatusOr> CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, @@ -206,45 +205,25 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // Allocate space for the input, filter, and output of the convolution. We // use a ScratchAllocator for this instead of calling allocator_ directly so // that our allocations don't leak. - // - // We don't put any data in these buffers, because (in theory, anyway) the - // speed of a conv isn't affected by the data being convolved. ScratchAllocator input_output_allocator(device_ordinal, allocator); - StatusOr maybe_input_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(input_shape)); - StatusOr maybe_filter_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(filter_shape)); - StatusOr maybe_output_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(output_shape)); - if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() || - !maybe_output_buf.ok()) { - LOG(WARNING) - << "Couldn't allocate space for input/filter/output of convolution " - << instr->ToString() << ". Falling back to default algorithm."; - return nullopt; - } - - DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie(); - DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie(); - DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie(); + TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(input_shape))); + TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(filter_shape))); + TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(output_shape))); // Although we don't have evidence this matters, zero out the buffers before // autotuning. It's conceivable that using uninitialized memory as the inputs // might affect performance if e.g. the inputs contain denormals, and this is // easy enough. - if (!stream.ThenMemZero(&input_buf, input_buf.size()) - .ThenMemZero(&filter_buf, filter_buf.size()) - .ThenMemZero(&output_buf, output_buf.size()) - .BlockHostUntilDone() - .ok()) { - LOG(WARNING) - << "Couldn't zero out input/filter/output buffer for convolution " - << instr->ToString() << ". Falling back to default algorithm."; - return nullopt; - } + TF_RETURN_IF_ERROR(stream.ThenMemZero(&input_buf, input_buf.size()) + .ThenMemZero(&filter_buf, filter_buf.size()) + .ThenMemZero(&output_buf, output_buf.size()) + .BlockHostUntilDone()); const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( input_shape, output_shape, dnums, stream_exec_); @@ -292,9 +271,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( best_result_bytes_used); } - LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString() - << " failed. Falling back to default algorithm."; - return nullopt; + return InternalError( + "All algorithms tried for convolution %s failed. 
Falling back to " + "default algorithm.", + instr->ToString().c_str()); } StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( @@ -305,12 +285,13 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( const auto& lhs_shape = instr->operand(0)->shape(); const auto& rhs_shape = instr->operand(1)->shape(); const auto& conv_result_shape = instr->shape().tuple_shapes(0); - optional> alg_scratch_and_tc; + StatusOr> alg_scratch_and_tc; if (call_target == kCudnnConvForwardCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kForward, /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, - instr->window(), instr->convolution_dimension_numbers(), instr); + alg_scratch_and_tc = + PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape, + /*filter_shape=*/rhs_shape, + /*output_shape=*/conv_result_shape, instr->window(), + instr->convolution_dimension_numbers(), instr); } else if (call_target == kCudnnConvBackwardInputCallTarget) { alg_scratch_and_tc = PickBestAlgorithm( CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, @@ -326,7 +307,8 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( << instr->ToString(); } - if (!alg_scratch_and_tc.has_value()) { + if (!alg_scratch_and_tc.ok()) { + LOG(ERROR) << alg_scratch_and_tc.status(); return false; } @@ -334,7 +316,8 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( bool tensor_ops_enabled; int64 scratch_bytes; - std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc; + std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = + alg_scratch_and_tc.ConsumeValueOrDie(); VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " << NumBytesToString(scratch_bytes) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index bc5d1ce94af..8b7749628a8 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -34,8 +35,9 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { // memory while timing the various convolution algorithms. If it's null, // we'll use the default allocator on the StreamExecutor. 
CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator) - : stream_exec_(stream_exec), allocator_(allocator) {} + DeviceMemoryAllocator* allocator, + Compiler* compiler) + : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} tensorflow::StringPiece name() const override { return "cudnn-convolution-algorithm-picker"; @@ -46,13 +48,14 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { private: StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - tensorflow::gtl::optional> PickBestAlgorithm( + StatusOr> PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null + Compiler* compiler_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 69ba91793dd..9b6de115ad7 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -210,11 +210,13 @@ StatusOr GpuElementalIrEmitter::EmitPowerOp( return make_sqrt(); } - if (hlo_module_config_.debug_options().xla_enable_fast_math() && - IsFPLiteralWithValue(rhs, -.5)) { + if (IsFPLiteralWithValue(rhs, -.5)) { VLOG(10) << "emitting pow(A, -.5) as 1/sqrt(A): " << op->ToString(); // LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX // rsqrt.approx instruction. + // + // TODO(jlebar): Does this happen with fastmath disabled? If not, should + // we force-enable it? TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt()); return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); } @@ -274,16 +276,18 @@ StatusOr GpuElementalIrEmitter::EmitAtan2( StatusOr GpuElementalIrEmitter::EmitTanh( PrimitiveType prim_type, llvm::Value* value) const { - // If we don't care much about precision, emit a fast approximation of - // tanh. - if (hlo_module_config_.debug_options().xla_enable_fast_math()) { - // Upcast F16 to F32 if necessary. - llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType(); - llvm::Value* input = b_->CreateFPCast(value, type); - llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); - return b_->CreateFPCast(fast_tanh, value->getType()); - } - return EmitLibdeviceMathCall("__nv_tanh", {value}, {prim_type}, prim_type); + // Emit a fast approximation of tanh instead of calling __nv_tanh. + // __nv_tanh is particularly bad because it contains branches, thus + // preventing LLVM's load-store vectorizer from working its magic across a + // function which contains tanh calls. + // + // This routine isn't numerically precise, but it's good enough for ML. + + // Upcast F16 to F32 if necessary. + llvm::Type* type = prim_type == F16 ? 
b_->getFloatTy() : value->getType(); + llvm::Value* input = b_->CreateFPCast(value, type); + llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); + return b_->CreateFPCast(fast_tanh, value->getType()); } llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 66aeb4efef4..6675dbd3f9e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -64,7 +64,7 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, hlo_module_config_(hlo_module_config) { b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_enable_fast_math())); + .xla_gpu_enable_fast_math())); } Status IrEmitter::DefaultAction(HloInstruction* hlo) { diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index cf44458a2ed..ff4ae1f9ef2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -180,7 +180,7 @@ std::unique_ptr GetTargetMachine( TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); llvm_ir::SetTargetOptions( /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_enable_fast_math(), + .xla_gpu_enable_fast_math(), &target_options); // Enable FMA synthesis. diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 76c9b6ab33b..d937123357e 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -72,6 +72,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -130,8 +131,12 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. +// +// It takes a compiler pointer, as passes may compile and execute HLOs on the +// fly for cuDNN verification or other purposes. Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { + DeviceMemoryAllocator* device_allocator, + Compiler* compiler) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); @@ -167,6 +172,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // elimination has to come after that pass. pipeline.AddPass(); + pipeline.AddPass(); + pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); @@ -245,8 +252,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // the gte(customcall, 0) would probably already be into a fusion node. We // can't simplify across HloComputation boundaries, so in this case we // wouldn't be able to simplify away the new_tuple bits. 
- pipeline.AddPass(stream_exec, - device_allocator); + pipeline.AddPass( + stream_exec, device_allocator, compiler); // Clean up new_tuple described above. pipeline.AddPass(); @@ -492,11 +499,15 @@ NVPTXCompiler::NVPTXCompiler() StatusOr> NVPTXCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { + // We dump the post-optimization HLO in RunBackend so no need to dump it here. + VLOG(2) << "*** HLO Before Optimization"; + XLA_VLOG_LINES(2, module->ToString()); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); tracing::ScopedActivity activity("HLO Transforms", module->name(), /*is_expensive=*/true); TF_RETURN_IF_ERROR( - OptimizeHloModule(module.get(), stream_exec, device_allocator)); + OptimizeHloModule(module.get(), stream_exec, device_allocator, this)); return std::move(module); } @@ -548,6 +559,7 @@ StatusOr> NVPTXCompiler::RunBackend( // include headers, so no need for us to print them ourselves. XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); XLA_VLOG_LINES(2, buffer_assignment->ToString()); + VLOG(2) << "*** HLO After Optimization"; XLA_VLOG_LINES(2, module->ToString()); const string xla_dump_optimized_hlo_proto_to = module->config().debug_options().xla_dump_optimized_hlo_proto_to(); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 90d2be118d9..858992a3264 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -174,6 +174,29 @@ StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); } +StatusOr MakeMapHlo( + tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation) { + CHECK(!operands.empty()) << "Map Hlo requires at least one operand."; + HloComputation* computation = operands.front()->parent(); + std::vector operand_shapes; + int64 max_operand_rank = 0; + for (const HloInstruction* operand : operands) { + CHECK_EQ(computation, operand->parent()); + operand_shapes.push_back(&operand->shape()); + max_operand_rank = + std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + } + std::vector map_dims(max_operand_rank); + std::iota(map_dims.begin(), map_dims.end(), 0); + TF_ASSIGN_OR_RETURN( + Shape map_shape, + ShapeInference::InferMapShape( + operand_shapes, map_computation->ComputeProgramShape(), map_dims)); + return computation->AddInstruction( + HloInstruction::CreateMap(map_shape, operands, map_computation)); +} + StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); @@ -251,6 +274,38 @@ StatusOr ElideDegenerateDims(HloInstruction* operand, return MakeReshapeHlo(output_shape, operand); } +StatusOr InsertDegenerateDims( + HloInstruction* operand, ArraySlice dims_to_insert) { + CHECK(c_is_sorted(dims_to_insert)); + + const Shape& operand_shape = operand->shape(); + int64 output_shape_rank = + operand_shape.dimensions_size() + dims_to_insert.size(); + for (auto dim_to_insert : dims_to_insert) { + CHECK_LT(dim_to_insert, output_shape_rank); + } + + std::vector output_shape_dim_bounds; + output_shape_dim_bounds.reserve(output_shape_rank); + int64 operand_dims_idx = 0; + int64 dims_to_insert_idx = 0; + for (int64 i = 0; i < output_shape_rank; ++i) { + if (dims_to_insert_idx < dims_to_insert.size() && + i == dims_to_insert[dims_to_insert_idx]) { + output_shape_dim_bounds.push_back(1); + ++dims_to_insert_idx; + } else { + 
output_shape_dim_bounds.push_back( + operand_shape.dimensions(operand_dims_idx)); + ++operand_dims_idx; + } + } + + Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(), + output_shape_dim_bounds); + return MakeReshapeHlo(output_shape, operand); +} + StatusOr PadVectorWithZeros(HloInstruction* operand, int64 zeros_to_prepend, int64 zeros_to_append) { diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 49b1402d689..5ff8946fb09 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -102,6 +102,12 @@ StatusOr MakeConcatHlo( StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dim_numbers); +// Creates a Map HLO instruction and adds it to the computation containing the +// operands. All operands must be in the same computation. +StatusOr MakeMapHlo( + tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation); + // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of // these add all the instructions they generate into the computation containing @@ -144,6 +150,16 @@ StatusOr ExpandFirstDimIntoNDims( StatusOr ElideDegenerateDims( HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_elide); +// Inserts (via reshape) a set of degenerate dimensions (dimensions containing +// exactly one element), `dims_to_insert` into `operand`. The dimensions in +// `dims_to_insert` refer to the dimensions in the result, and hence should be +// less than the rank of the result. Also, `dims_to_insert` must be sorted. +// +// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is +// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34]. +StatusOr InsertDegenerateDims( + HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_insert); + // Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the // front and `zeros_to_append` zeros in the back. StatusOr PadVectorWithZeros(HloInstruction* operand, diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 71b44507cc7..8e0d38b6a63 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -143,8 +143,47 @@ TokKind HloLexer::LexToken() { return TokKind::kLparen; case ')': return TokKind::kRparen; - case '/': - return LexComment(); + case '/': { + if (PeekCurrentChar() == '*') { + // This is the start of a /*...*/ delimited comment. Save the current + // location in case the comment is unterminated so the error message + // will point to the beginning of the comment. + const char* comment_start = current_ptr_; + current_ptr_++; + // Advance until '*/' is found. + while (true) { + int current = GetNextChar(); + if (current == '*' && PeekCurrentChar() == '/') { + // End of comment. + current_ptr_++; + break; + } + if (current == kEOF) { + // Unterminated comment. + current_ptr_ = comment_start; + return TokKind::kError; + } + } + // Return no token for the comment. Keep lexing. + continue; + } else if (PeekCurrentChar() == '/') { + // This is the start of a '//' delimited comment. Throw away + // everything until end of line or file. The end-of-line character(s) + // are left unlexed in the buffer which is harmless because these are + // skipped later by the lexer. 
This approach enables support for + // different end-of-line encodings. + while (true) { + int current = PeekCurrentChar(); + if (current == kEOF || current == '\n' || current == '\r') { + break; + } + current_ptr_++; + } + continue; + } + // A lone '/' is an error. + return TokKind::kError; + } case '"': return LexString(); } @@ -357,16 +396,6 @@ tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { return StringPieceFromPointers(start, end); } -TokKind HloLexer::LexComment() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); - static LazyRE2 comment_pattern = {R"(\/\*.*?\*\/)"}; - if (RE2::Consume(&consumable, *comment_pattern)) { - current_ptr_ = consumable.begin(); - return TokKind::kComment; - } - return TokKind::kError; -} - // Lexes quoted string with escaping characters. If matched, the quoted string // will be unescaped and stored to str_val_. TokKind HloLexer::LexString() { @@ -412,8 +441,6 @@ string TokKindToString(TokKind kind) { return "kRparen"; case TokKind::kArrow: return "kArrow"; - case TokKind::kComment: - return "kComment"; case TokKind::kw_HloModule: return "kw_HloModule"; case TokKind::kw_ENTRY: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index ceb674f25e9..003ac34ace5 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -105,7 +105,6 @@ class HloLexer { TokKind LexShape(); TokKind LexConstant(); TokKind LexNumberOrPattern(); - TokKind LexComment(); TokKind LexString(); const tensorflow::StringPiece buf_; diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index b57c940238f..c577b4359aa 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -231,6 +231,7 @@ HLO_MATCHER(Tanh); HLO_MATCHER(Trace); HLO_MATCHER(Transpose); HLO_MATCHER(Tuple); +HLO_MATCHER(TupleSelect); HLO_MATCHER(While); // The special cases below let you check additional information about the diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 84f2d3f5fbc..1b256cd00e6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -166,7 +166,7 @@ class HloModuleGroupMetadata { // // Precondition: IsCompanionWhile(instruction) is true. const std::unordered_set& Companions( - HloInstruction* instruction) const { + const HloInstruction* instruction) const { CHECK_EQ(companion_set_index_.count(instruction), 1); return companion_set(companion_set_index_.at(instruction)); } @@ -243,7 +243,7 @@ class HloModuleGroupMetadata { companion_sets_; // Map from each companion while instruction to the index into companion_set_. - tensorflow::gtl::FlatMap companion_set_index_; + tensorflow::gtl::FlatMap companion_set_index_; // Map from computation to the instruction using it (a kWhile, kConditional). tensorflow::gtl::FlatMap diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 9fd0ade1531..0dc56761482 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -37,24 +38,38 @@ namespace xla { std::vector HloModuleGroupUtil::GlobalPredecessors( HloInstruction* instruction) { - std::vector predecessors; + std::vector + predecessors; // Use a vector to avoid non-determinism. + tensorflow::gtl::FlatSet unique; - // Adds to the unique predecessors list and also add companion instructions - // if the given predecessor has those. + // Adds to the unique predecessors list; if the predecessors is a companion + // instruction, also add companion instructions; if the predecessors is a + // cross-module all-reduce, also add the all-reduce instructions in the same + // group. auto add_unique_predecessor = [&](HloInstruction* predecessor) { - if (std::find(predecessors.begin(), predecessors.end(), predecessor) != - predecessors.end()) { + if (unique.find(predecessor) != unique.end()) { return; } - if (!metadata_.IsCompanionInstruction(predecessor)) { - predecessors.push_back(predecessor); + if (metadata_.IsCompanionInstruction(predecessor)) { + for (HloInstruction* instr : metadata_.Companions(predecessor)) { + if (unique.insert(instr).second) { + predecessors.push_back(instr); + } + } return; } - for (HloInstruction* companion : metadata_.Companions(predecessor)) { - predecessors.push_back(companion); + if (predecessor->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : + metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) { + if (unique.insert(instr).second) { + predecessors.push_back(instr); + } + } + return; } + unique.insert(predecessor); + predecessors.push_back(predecessor); }; - // If the given instruction is a companion instruction, we need to find the // predecessors of all of its companion instructions. If the instruction is an // all-reduce, we need to find the predecessors of all the peer all-reduce @@ -98,22 +113,37 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( std::vector HloModuleGroupUtil::GlobalSuccessors( HloInstruction* instruction) { - std::vector successors; + std::vector + successors; // Use a vector to avoid non-determinism. + tensorflow::gtl::FlatSet unique; - // Adds to the unique successors list and also add companion instructions - // if the given successor has those. + // Adds to the unique successors list; if the successor is a companion + // instruction, also add companion instructions; if the successor is a + // cross-module all-reduce, also add the all-reduce instructions in the same + // group. 
auto add_unique_successor = [&](HloInstruction* successor) { - if (std::find(successors.begin(), successors.end(), successor) != - successors.end()) { + if (unique.find(successor) != unique.end()) { return; } - if (!metadata_.IsCompanionInstruction(successor)) { - successors.push_back(successor); + if (metadata_.IsCompanionInstruction(successor)) { + for (HloInstruction* instr : metadata_.Companions(successor)) { + if (unique.insert(instr).second) { + successors.push_back(instr); + } + } return; } - for (HloInstruction* companion : metadata_.Companions(successor)) { - successors.push_back(companion); + if (successor->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : + metadata_.GetAllReduceGroup(*successor->all_reduce_id())) { + if (unique.insert(instr).second) { + successors.push_back(instr); + } + } + return; } + unique.insert(successor); + successors.push_back(successor); }; // If the given instruction is a companion instruction, we need to find the diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 2a8c6ecd924..4b3cd99dc06 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1824,7 +1824,6 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, break; } case TokKind::kComma: - case TokKind::kComment: // Skip. lexer_.Lex(); break; diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 4cd21841f4c..5990a3d4784 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1560,6 +1560,81 @@ ENTRY consts { "last"); } +TEST_F(HloParserTest, Comments) { + const string original = R"(/* module description. */ +HloModule comments: + +ENTRY /*comment*/ c1 { + /* blah */ + ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/}) + /* comment */ +} + +/* something else */ + +)"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, MultilineComments) { + const string original = R"(HloModule multiline_comment: +ENTRY c1 { + /* + ROOT foo = f32[1]{0} constant({12345}) + */ + ROOT const1 = f32[1]{0} constant({12345}) +/* +a +b +c +d + +*/ +})"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, UnterminatedComment) { + const string original = R"(HloModule unterminated_comment: +ENTRY c1 { +/* unterminated + ROOT const1 = f32[1]{0} constant({12345}) +})"; + // Verify that the error message points to the beginning of the unterminated + // comment. 
+ ExpectHasSubstr(ParseHloString(original).status().error_message(), + "/* unterminated\n^"); +} + +TEST_F(HloParserTest, SlashSlashComments) { + const string original = R"(HloModule slash_slash_comment: +// Garbage +ENTRY c1 { + // Foo bar + ROOT const1 = f32[1]{0} constant({12345}) // Something else +})"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) { + const string original = + "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo " + "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) { + const string original = + "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo " + "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + TEST_F(HloParserTest, MultipleEntries) { const string original = R"(HloModule multiple_entries: ENTRY c1 { diff --git a/tensorflow/compiler/xla/service/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h index 533429608bc..4458c251dee 100644 --- a/tensorflow/compiler/xla/service/hlo_token.h +++ b/tensorflow/compiler/xla/service/hlo_token.h @@ -44,7 +44,6 @@ enum class TokKind { kRparen, // ( ) kArrow, // -> - kComment, // /*xxx*/ // Keywords kw_HloModule, diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 9b109022fbf..db6b910b32f 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -104,7 +104,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { } // No "synchronize all activity" implemented for this platform at the moment. - bool SynchronizeAllActivity() override { return false; } + bool SynchronizeAllActivity() override { return true; } bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override { return false; } diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc new file mode 100644 index 00000000000..45ca731153b --- /dev/null +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -0,0 +1,350 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/scatter_expander.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +using tensorflow::gtl::ArraySlice; + +// Transposes the given scatter_indices such that the index_vector_dim becomes +// the most-minor dimension. +static StatusOr TransposeIndexVectorDimToLast( + HloInstruction* scatter_indices, int64 index_vector_dim) { + const Shape& scatter_indices_shape = scatter_indices->shape(); + + if (scatter_indices_shape.dimensions_size() == index_vector_dim) { + return scatter_indices; + } + + if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { + return scatter_indices; + } + + std::vector permutation; + permutation.reserve(scatter_indices_shape.dimensions_size()); + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != index_vector_dim) { + permutation.push_back(i); + } + } + permutation.push_back(index_vector_dim); + return MakeTransposeHlo(scatter_indices, permutation); +} + +// Canonicalizes the scatter_indices tensor in order to keep them uniform while +// performing the scatter operation. +static StatusOr CanonicalizeScatterIndices( + HloInstruction* scatter_indices, int64 index_vector_dim) { + // Transpose the non-index-vector dimensions to the front. + TF_ASSIGN_OR_RETURN( + HloInstruction * transposed_scatter_indices, + TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); + bool indices_are_scalar = + index_vector_dim == scatter_indices->shape().dimensions_size(); + + // The number of dimensions in scatter_indices that are index dimensions. + const int64 index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1; + + // If there is only one index (i.e. scatter_indices has rank 1 and this + // scatter is really just a dynamic update slice) add a leading degenerate + // dimension for uniformity. Otherwise create a "collapsed" leading dimension + // that subsumes all of the non-index-vector dimensions. + const Shape& shape = transposed_scatter_indices->shape(); + if (shape.dimensions_size() == index_dims_in_scatter_indices) { + return PrependDegenerateDims(transposed_scatter_indices, 1); + } else { + // Collapse all but the dimensions (0 or 1) in scatter_indices containing + // the index vectors. + return CollapseFirstNDims( + transposed_scatter_indices, + shape.dimensions_size() - index_dims_in_scatter_indices); + } +} + +// Permutes the `updates` tensor such that all the scatter dims appear in the +// major dimensions and all the window dimensions appear in the minor +// dimensions. 
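+// For example (illustrative): if `updates` has rank 4 and update_window_dims
+// is {1, 3}, the computed permutation is {0, 2, 1, 3}, i.e. the scatter dims
+// {0, 2} end up as the major dimensions and the window dims {1, 3} as the
+// minor dimensions.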
+static StatusOr PermuteScatterAndWindowDims( + HloInstruction* updates, ArraySlice update_window_dims) { + std::vector permutation; + const int64 updates_rank = ShapeUtil::Rank(updates->shape()); + permutation.reserve(updates_rank); + + for (int64 i = 0; i < updates_rank; ++i) { + bool is_scatter_dim = !c_binary_search(update_window_dims, i); + if (is_scatter_dim) { + permutation.push_back(i); + } + } + for (auto window_dim : update_window_dims) { + permutation.push_back(window_dim); + } + + return MakeTransposeHlo(updates, permutation); +} + +// Expands or contracts the scatter indices in the updates tensor. +static StatusOr AdjustScatterDims( + const Shape& scatter_indices_shape, HloInstruction* updates, + int64 index_vector_dim) { + int64 num_scatter_dims = scatter_indices_shape.dimensions_size(); + if (index_vector_dim < scatter_indices_shape.dimensions_size()) { + --num_scatter_dims; + } + if (num_scatter_dims == 0) { + // If there are no scatter dims, this must be a dynamic-update-slice kind of + // scatter. In this case, we prepend a degenerate dimension to work + // uniformly in the while loop. + return PrependDegenerateDims(updates, 1); + } + return CollapseFirstNDims(updates, num_scatter_dims); +} + +// Expands an index vector from the scatter_indices tensor into a vector that +// can be used to dynamic-update-slice to perform the scatter update. +static StatusOr ExpandIndexVectorIntoOperandSpace( + HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers, + int64 operand_rank) { + HloComputation* computation = index_vector->parent(); + const Shape& index_shape = index_vector->shape(); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); + + // We extract out individual components from the smaller index and concatenate + // them (interspersing zeros as needed) into the larger index. + std::vector expanded_index_components; + + for (int i = 0; i < operand_rank; i++) { + int64 index_vector_dim_index = + FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i); + if (index_vector_dim_index != + dim_numbers.scatter_dims_to_operand_dims_size()) { + TF_ASSIGN_OR_RETURN( + HloInstruction * component_to_concat, + MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, + /*limit_indices=*/{index_vector_dim_index + 1}, + /*strides=*/{1})); + expanded_index_components.push_back(component_to_concat); + } else { + expanded_index_components.push_back(zero); + } + } + + return MakeConcatHlo(expanded_index_components, /*dimension=*/0); +} + +// Body of the while loop that performs the scatter operation using other HLOs. +static StatusOr> ScatterLoopBody( + HloInstruction* scatter, HloInstruction* induction_var, + const std::vector& loop_state) { + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + CHECK_EQ(loop_state.size(), 3); + HloInstruction* operand = loop_state[0]; + HloInstruction* scatter_indices = loop_state[1]; + HloInstruction* updates = loop_state[2]; + + bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1; + CHECK_EQ(has_scalar_indices, + dim_numbers.index_vector_dim() == + scatter->operand(1)->shape().dimensions_size()); + + // Build a vector form of the induction variable of the while loop. 
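+  // For example, a scalar induction variable with value 7 becomes the
+  // one-element vector [7], which can then be used as a start index for the
+  // dynamic-slice operations below.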
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * induction_var_as_vector,
+      MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
+                       /*result_shape_bounds=*/{1}));
+
+  // Pick the index to scatter from scatter_indices based on the induction_var
+  // and transform that to an index into the `operand` space.
+  HloInstruction* index_vector;
+  if (has_scalar_indices) {
+    TF_ASSIGN_OR_RETURN(
+        index_vector,
+        MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1}));
+  } else {
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * index_into_scatter_indices,
+        PadVectorWithZeros(induction_var_as_vector,
+                           /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
+    int index_vector_size = scatter_indices->shape().dimensions(1);
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * index_vector_2d,
+        MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices,
+                            {1, index_vector_size}));
+    TF_ASSIGN_OR_RETURN(index_vector,
+                        ElideDegenerateDims(index_vector_2d, {0}));
+  }
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * scatter_slice_start,
+      ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers,
+                                        operand->shape().dimensions_size()));
+
+  // Extract the slice to be used to update from `updates` tensor for the
+  // induction_var corresponding to this iteration of the while loop.
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * index_into_updates,
+      PadVectorWithZeros(
+          induction_var_as_vector, /*zeros_to_prepend=*/0,
+          /*zeros_to_append=*/updates->shape().dimensions_size() - 1));
+  std::vector<int64> update_slice_bounds(updates->shape().dimensions().begin(),
+                                         updates->shape().dimensions().end());
+  update_slice_bounds[0] = 1;
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * update_slice,
+      MakeDynamicSliceHlo(updates, index_into_updates, update_slice_bounds));
+  TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter,
+                      ElideDegenerateDims(update_slice, {0}));
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * update_slice_with_dims_inserted,
+      InsertDegenerateDims(update_slice_for_scatter,
+                           AsInt64Slice(dim_numbers.inserted_window_dims())));
+
+  // Extract the slice to update from `operand` tensor.
+  const Shape& update_slice_shape = update_slice_with_dims_inserted->shape();
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * operand_slice_to_update,
+      MakeDynamicSliceHlo(operand, scatter_slice_start,
+                          AsInt64Slice(update_slice_shape.dimensions())));
+
+  // Compute the new value for the slice to be updated in `operand` tensor by
+  // combining the existing value and the update value using the update
+  // computation.
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * updated_operand_slice,
+      MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted},
+                 scatter->to_apply()));
+
+  // Write the updated value of the slice into `operand` tensor.
+  TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand,
+                      MakeDynamicUpdateSliceHlo(operand, updated_operand_slice,
+                                                scatter_slice_start));
+
+  return StatusOr<std::vector<HloInstruction*>>{
+      {updated_operand, scatter_indices, updates}};
+}
+
+// High Level Algorithm.
+//
+// 1. Canonicalize the scatter_indices tensor such that it has rank 2, where
+//    each row is an index into the operand.
+// 2. Canonicalize the updates tensor such that it has rank `num_window_dims+1`
+//    and the scatter dim is the most-major dimension.
+// 3. Iterate over the set of indices in the canonicalized scatter_indices
+//    tensor using a while loop, updating the operand for each such index. Each
+//    iteration of this while loop performs the following:
+//      a. Pick the index from scatter_indices for this iteration.
+//      b. Transform this index into an index into the operand space.
+//      c. Extract the slice to be used to update from the updates tensor.
+//      d. Extract the slice to update from the operand tensor.
+//      e. Compute the new value for the slice to update by combining the
+//         slices from c. and d. using the update_computation of scatter.
+//      f. Write the updated value of the slice into the operand tensor.
+
+StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
+    HloInstruction* scatter) {
+  HloInstruction* operand = scatter->mutable_operand(0);
+  HloInstruction* scatter_indices = scatter->mutable_operand(1);
+  HloInstruction* updates = scatter->mutable_operand(2);
+  const ScatterDimensionNumbers& dim_numbers =
+      scatter->scatter_dimension_numbers();
+
+  // If the updates tensor is empty, there is no need to update the operand. We
+  // can return the operand as is.
+  if (ShapeUtil::IsZeroElementArray(updates->shape())) {
+    return operand;
+  }
+
+  // Compute the trip count for the while loop to be used for scatter. This
+  // should be the number of indices we should scatter into the operand.
+  const Shape& scatter_indices_shape = scatter_indices->shape();
+  int64 scatter_loop_trip_count = 1;
+  for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
+    if (i != dim_numbers.index_vector_dim()) {
+      scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
+    }
+  }
+  if (!IsInt32(scatter_loop_trip_count)) {
+    return Unimplemented(
+        "Scatter operations with more than 2147483647 scatter indices are not "
+        "supported. This error occurred for %s.",
+        scatter->ToString().c_str());
+  }
+
+  // Canonicalize the scatter_indices, after which the size of its most-major
+  // dimension must be the same as the while loop trip count.
+  TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices,
+                      CanonicalizeScatterIndices(
+                          scatter_indices, dim_numbers.index_vector_dim()));
+  CHECK_EQ(scatter_loop_trip_count,
+           canonical_scatter_indices->shape().dimensions(0));
+
+  // Canonicalize the updates, after which the size of its most-major dimension
+  // must be the same as the while loop trip count.
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * canonical_updates,
+      PermuteScatterAndWindowDims(
+          updates, AsInt64Slice(dim_numbers.update_window_dims())));
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * adjusted_canonical_updates,
+      AdjustScatterDims(scatter_indices->shape(), canonical_updates,
+                        dim_numbers.index_vector_dim()));
+  CHECK_EQ(scatter_loop_trip_count,
+           adjusted_canonical_updates->shape().dimensions(0));
+
+  // The while loop that implements the scatter operation.
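+  // For example, for the s32[3,3] operand, s32[2] indices {0, 2} and s32[2,3]
+  // updates exercised by the TensorFlowScatterV1_Update test in
+  // scatter_test.cc below, the trip count is 2, and each iteration
+  // dynamic-update-slices one row of `updates` into the operand at the row
+  // named by the corresponding scatter index.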
+ StatusOr> scatter_loop_result_status = + WhileUtil::MakeCountedLoop( + scatter->parent(), scatter_loop_trip_count, + {operand, canonical_scatter_indices, adjusted_canonical_updates}, + [&](HloInstruction* induction_var, + const std::vector& loop_state) { + return ScatterLoopBody(scatter, induction_var, loop_state); + }); + TF_ASSIGN_OR_RETURN(std::vector scatter_loop_result, + scatter_loop_result_status); + return scatter_loop_result.front(); +} + +StatusOr ScatterExpander::Run(HloModule* module) { + std::vector scatter_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kScatter) { + scatter_instrs.push_back(instr); + } + } + } + + for (auto instr : scatter_instrs) { + TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr)); + TF_RETURN_IF_ERROR( + instr->parent()->ReplaceInstruction(instr, expanded_root)); + } + + return !scatter_instrs.empty(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h new file mode 100644 index 00000000000..8f735e877d2 --- /dev/null +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +class ScatterExpander : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { return "scatter_expander"; } + StatusOr Run(HloModule* module) override; + + private: + StatusOr ExpandScatter(HloInstruction* scatter); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 34869cc5078..b69c346f1e6 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1014,12 +1014,13 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { } /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { + if (!IsTuple(shape)) { + return 1; + } int64 count = 0; - ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) { - if (IsLeafIndex(shape, index)) { - ++count; - } - }); + for (const Shape& subshape : shape.tuple_shapes()) { + count += GetLeafCount(subshape); + } return count; } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 42d52aee780..0f8cffd466c 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -709,6 +709,21 @@ xla_test( ], ) +xla_test( + name = "scatter_test", + srcs = ["scatter_test.cc"], + deps = [ + ":client_library_test_base", + ":hlo_test_base", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + # Repeat dot_operation_runtime_test with single-threaded eigen. xla_test( name = "dot_operation_single_threaded_runtime_test", @@ -2061,6 +2076,8 @@ tf_cc_test( xla_test( name = "test_utils_test", srcs = ["test_utils_test.cc"], + # There is nothing backend specific in this test, so just pick an arbitrary backend. 
+ backends = ["cpu"], deps = [ ":local_client_test_base", ":test_utils", diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 4a6e8a31241..b04a3b105ca 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -74,8 +74,9 @@ class ClientLibraryTestBase : public ::testing::Test { string TestName() const; void SetFastMathDisabled(bool disabled) { - execution_options_.mutable_debug_options()->set_xla_enable_fast_math( - !disabled); + auto* opts = execution_options_.mutable_debug_options(); + opts->set_xla_cpu_enable_fast_math(!disabled); + opts->set_xla_gpu_enable_fast_math(!disabled); } void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 73edad89dc8..92c93f08b2e 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -1464,5 +1464,24 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); } +XLA_TEST_F(HloTestBase, ReduceWindowF16) { + const string hlo_string = R"( +HloModule reduce-window + +%identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] { + %param0 = f16[] parameter(0) + ROOT %param1 = f16[] parameter(1) +} + +ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] { + %parameter.0 = f16[81,8]{1,0} parameter(0) + %parameter.1 = f16[] parameter(1) + ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc new file mode 100644 index 00000000000..922d70b7526 --- /dev/null +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -0,0 +1,615 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class ScatterTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, Literal* operand, + Literal* scatter_indices, Literal* updates) { + RunTest(hlo_text, {operand, scatter_indices, updates}); + } + + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(ScatterTest, TensorFlowScatterV1_Update) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) { + const string hlo_text = R"( +HloModule TensorFlowScatter_Add + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + 
LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) { + const string hlo_text = R"( +HloModule TensorFlowScatter_Mul + +mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=mul_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) { + const string hlo_text = R"( +HloModule TensorFlowScatter_F32 + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f32[2,3] parameter(2) + ROOT scatter = f32[3,3] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({2, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowScatterMultipleBatchDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=2 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 
9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterNd) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNdNonDefaultIndexVectorDim + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, DynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule DynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0,1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + 
scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + std::unique_ptr updates = + LiteralUtil::CreateR3({{{10}}, {{20}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_ZeroDimBounds + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,0] parameter(2) + ROOT scatter = s32[3,0] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, NoUpdateWindowDims) { + const string hlo_text = R"( +HloModule Scatter_NoUpdateWindowDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, OutOfBoundsIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = u32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + 
scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, NegativeIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, OneScalarIndex) { + const char* hlo_text = R"( +HloModule OneScalarIndex + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[2,3,2]{2,1,0} parameter(0) + index = s32[] parameter(1) + updates = s32[1,3,2]{2,1,0} parameter(2) + ROOT scatter = s32[2,3,2]{2,1,0} scatter(operand, index, updates), + to_apply=update_s32, + update_window_dims={0,1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); + std::unique_ptr updates = + LiteralUtil::CreateR3({{{10, 20}, {30, 40}, {50, 60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, ScalarUpdate) { + const char* hlo_text = R"( +HloModule ScalarUpdate + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[4]{0} parameter(0) + index = s32[] parameter(1) + updates = s32[] parameter(2) + ROOT scatter = s32[4]{0} scatter(operand, index, updates), + to_apply=update_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); + std::unique_ptr updates = LiteralUtil::CreateR0(25); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, EmptyIndices) { + const string hlo_text = R"( +HloModule EmptyIndices + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[0] parameter(1) + updates = s32[0] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={}, + inserted_window_dims={0}, + 
scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR1({}); + std::unique_ptr updates = LiteralUtil::CreateR1({}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 26479370132..faeec657b66 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -208,16 +208,12 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -std::unique_ptr MakeRandomNonwrappingSliceIndex( - const Shape& input_shape, const Shape& slice_shape, - std::minstd_rand0* engine) { - const int64 rank = ShapeUtil::Rank(input_shape); - std::vector start_indices(rank); +std::unique_ptr MakeRandomIndex( + tensorflow::gtl::ArraySlice index_space, std::minstd_rand0* engine) { + std::vector start_indices(index_space.size()); if (engine != nullptr) { - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution generator(0, upper_bound); + for (int i = 0; i < index_space.size(); ++i) { + std::uniform_int_distribution generator(0, index_space[i]); start_indices[i] = generator(*engine); } } @@ -267,37 +263,42 @@ std::vector FindConstrainedUses( StatusOr> CreateLiteralForConstrainedUses( const tensorflow::gtl::ArraySlice constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { - HloInstruction* needs_index = nullptr; - HloInstruction* needs_constant = nullptr; + std::vector index_space; + bool needs_constant = false; ConstantType constant_type = ConstantType::kUnknown; for (HloInstruction* use : constrained_uses) { switch (use->opcode()) { case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - if (needs_index != nullptr) { - auto needs_index_shape = needs_index->shape(); - auto use_shape = use->shape(); - if (needs_index->opcode() == HloOpcode::kDynamicSlice) { - needs_index_shape = needs_index->operand(0)->shape(); + case HloOpcode::kDynamicUpdateSlice: { + const Shape& indexed_shape = use->operand(0)->shape(); + const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice + ? 
use->shape() + : use->operand(1)->shape(); + const int64 rank = ShapeUtil::Rank(indexed_shape); + if (!index_space.empty()) { + TF_RET_CHECK(rank == index_space.size()); + for (int64 i = 0; i < rank; ++i) { + index_space[i] = std::min( + index_space[i], ShapeUtil::GetDimension(indexed_shape, i) - + ShapeUtil::GetDimension(slice_shape, i)); } - if (use->opcode() == HloOpcode::kDynamicSlice) { - use_shape = use->operand(0)->shape(); - } - if (!ShapeUtil::Equal(needs_index_shape, use_shape)) { - return Unimplemented( - "Conflicting operand generation slice index constraints\n"); + } else { + index_space.resize(rank); + for (int64 i = 0; i < rank; ++i) { + index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); } } - needs_index = use; break; + } case HloOpcode::kReduce: case HloOpcode::kReduceWindow: - needs_constant = use; + needs_constant = true; constant_type = GetInitValue(*use->to_apply()); break; case HloOpcode::kSelectAndScatter: - needs_constant = use; + needs_constant = true; constant_type = GetInitValue(*use->scatter()); break; @@ -307,16 +308,14 @@ StatusOr> CreateLiteralForConstrainedUses( use->ToString().c_str()); } } - if (needs_index != nullptr && needs_constant != nullptr) { + if (!index_space.empty() && needs_constant) { return Unimplemented( - "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " - "constant: %s\n", - needs_index->ToString().c_str(), needs_constant->ToString().c_str()); + "Conflicting operand generation constraints. Dynamically indexes a " + "shape and is the init value of a reduction."); } - if (needs_index != nullptr) { - return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), - needs_index->shape(), engine); - } else if (needs_constant != nullptr) { + if (!index_space.empty()) { + return MakeRandomIndex(index_space, engine); + } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); @@ -356,8 +355,8 @@ StatusOr>> MakeFakeArguments( auto engine = pseudo_random ? 
MakeUnique() : nullptr; std::vector> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( - *dataflow, *params[i], engine.get())); + arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get()) + .ValueOrDie(); } return std::move(arguments); } diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index a2f0338e259..64d9e2031eb 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -72,5 +72,60 @@ XLA_TEST_F(TestUtilsTest, Token) { TF_ASSERT_OK(MakeFakeArguments(module.get()).status()); } +XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { + auto module = ParseHloString( + R"(HloModule index_space_module + + ENTRY IndexSpace { + index_param = s32[3]{0} parameter(0) + array_param.1 = f32[123,4,789]{0,1,2} parameter(1) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) + dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3} + ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} + })") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 3); + const Literal& index_arg = *args[0]; + + EXPECT_EQ(index_arg.Get({0}), 0); + + EXPECT_GE(index_arg.Get({1}), 0); + EXPECT_LE(index_arg.Get({1}), 2); + + EXPECT_GE(index_arg.Get({2}), 0); + EXPECT_LE(index_arg.Get({2}), 3); +} + +XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { + auto module = ParseHloString( + R"(HloModule index_space_module + + ENTRY IndexSpace { + index_param = s32[3]{0} parameter(0) + array_param.1 = f32[123,4,789]{0,1,2} parameter(1) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) + update_param.1 = f32[1,2,3]{0,1,2} parameter(3) + update_param.2 = f32[3,2,2]{0,1,2} parameter(4) + + dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param) + ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) + })") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 5); + const Literal& index_arg = *args[0]; + + EXPECT_EQ(index_arg.Get({0}), 0); + + EXPECT_GE(index_arg.Get({1}), 0); + EXPECT_LE(index_arg.Get({1}), 2); + + EXPECT_GE(index_arg.Get({2}), 0); + EXPECT_LE(index_arg.Get({2}), 3); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 10c0adc6707..3b72eb17c60 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -104,15 +104,6 @@ message DebugOptions { // interpretation of this value is left to the backends. int32 xla_backend_optimization_level = 31; - // When true, "unsafe" mathematical optimizations are enabled. These - // transformations include but are not limited to: - // - // - Reducing the precision of operations (e.g. using an approximate sin - // function, or transforming x/y into x * (1/y)). - // - Assuming that operations never produce or consume NaN or +/- Inf. - // - Assuming that +0 and -0 are indistinguishable. - bool xla_enable_fast_math = 32; - // Embed the compiler IR as a string in the executable. 
bool xla_embed_ir_in_executable = 33; @@ -194,6 +185,16 @@ message DebugOptions { // Maximum kernel unroll factor for the GPU backend. int32 xla_gpu_max_kernel_unroll_factor = 98; + // When true, "unsafe" mathematical optimizations are enabled. These + // transformations include but are not limited to: + // + // - Reducing the precision of operations (e.g. using an approximate sin + // function, or transforming x/y into x * (1/y)). + // - Assuming that operations never produce or consume NaN or +/- Inf. + // - Assuming that +0 and -0 are indistinguishable. + bool xla_cpu_enable_fast_math = 99; + bool xla_gpu_enable_fast_math = 100; + // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. map xla_backend_extra_options = 500; diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc index 1790b4bc116..a25a641cdb4 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -218,11 +218,11 @@ class ToBigtableOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC( ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done); - IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); std::unique_ptr iterator; OP_REQUIRES_OK_ASYNC( ctx, - dataset->MakeIterator(&iter_ctx, "ToBigtableOpIterator", &iterator), + dataset->MakeIterator(IteratorContext(ctx), "ToBigtableOpIterator", + &iterator), done); int64 timestamp_int; @@ -245,9 +245,10 @@ class ToBigtableOp : public AsyncOpKernel { ::google::cloud::bigtable::BulkMutation mutation; // TODO(saeta): Make # of mutations configurable. for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) { - OP_REQUIRES_OK_ASYNC( - ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), - done); + OP_REQUIRES_OK_ASYNC(ctx, + iterator->GetNext(IteratorContext(ctx), + &components, &end_of_sequence), + done); if (!end_of_sequence) { OP_REQUIRES_OK_ASYNC( ctx, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc index 9e49fa35db4..bd32672aa99 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -53,7 +53,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, BigtableTableResource* table, @@ -61,7 +61,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { std::vector columns, const DataTypeVector& output_types, std::vector output_shapes) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), input_(input), table_(table), column_families_(std::move(column_families)), @@ -80,8 +80,8 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr(new Iterator( - {this, strings::StrCat(prefix, "::BigtableLookupDataset")})); + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BigtableLookup")})); } const DataTypeVector& output_dtypes() const override { @@ -96,6 +96,14 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { return "BigtableLookupDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + 
DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + private: static ::google::cloud::bigtable::Filter MakeFilter( const std::vector& column_families, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc index e960719614a..a803fdcb496 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -35,11 +35,13 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, string prefix) - : GraphDatasetBase(ctx), table_(table), prefix_(std::move(prefix)) { + : DatasetBase(DatasetContext(ctx)), + table_(table), + prefix_(std::move(prefix)) { table_->Ref(); } @@ -47,8 +49,8 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr(new Iterator( - {this, strings::StrCat(prefix, "::BigtablePrefixKeyDataset")})); + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BigtablePrefixKey")})); } const DataTypeVector& output_dtypes() const override { @@ -68,6 +70,14 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { BigtableTableResource* table() const { return table_; } + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + private: class Iterator : public BigtableReaderDatasetIterator { public: diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc index 96d3565d9b9..5cd0371c79f 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -39,11 +39,11 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, string start_key, string end_key) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), table_(table), start_key_(std::move(start_key)), end_key_(std::move(end_key)) { @@ -54,8 +54,8 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr(new Iterator( - {this, strings::StrCat(prefix, "::BigtableRangeKeyDataset")})); + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BigtableRangeKey")})); } const DataTypeVector& output_dtypes() const override { @@ -75,6 +75,14 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { BigtableTableResource* table() const { return table_; } + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + private: class Iterator : public BigtableReaderDatasetIterator { public: diff --git 
a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc index a1a63a975af..6928d9423c8 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc @@ -52,11 +52,11 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, string prefix, string start_key, string end_key) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), table_(table), key_range_(MakeMultiModeKeyRange( std::move(prefix), std::move(start_key), std::move(end_key))) { @@ -68,7 +68,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( - {this, strings::StrCat(prefix, "::BigtableSampleKeyPairsDataset")})); + {this, strings::StrCat(prefix, "::BigtableSampleKeyPairs")})); } const DataTypeVector& output_dtypes() const override { @@ -87,6 +87,14 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { return "BigtableSampleKeyPairsDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + private: static MultiModeKeyRange MakeMultiModeKeyRange(string prefix, string start_key, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc index a5a47cfe2dc..a759fb50639 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc @@ -31,10 +31,10 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table) - : GraphDatasetBase(ctx), table_(table) { + : DatasetBase(DatasetContext(ctx)), table_(table) { table_->Ref(); } @@ -43,7 +43,7 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( - {this, strings::StrCat(prefix, "::BigtableSampleKeysDataset")})); + {this, strings::StrCat(prefix, "::BigtableSampleKeys")})); } const DataTypeVector& output_dtypes() const override { @@ -63,6 +63,14 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { BigtableTableResource* table() const { return table_; } + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + private: class Iterator : public DatasetIterator { public: diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc index 13cb8681679..78a920b0776 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -84,7 +84,7 @@ class 
BigtableScanDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, string prefix, string start_key, string end_key, @@ -92,7 +92,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel { std::vector columns, float probability, const DataTypeVector& output_types, std::vector output_shapes) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), table_(table), prefix_(std::move(prefix)), start_key_(std::move(start_key)), @@ -111,8 +111,8 @@ class BigtableScanDatasetOp : public DatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr(new Iterator( - {this, strings::StrCat(prefix, "::BigtableScanDataset")})); + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BigtableScan")})); } const DataTypeVector& output_dtypes() const override { @@ -129,6 +129,14 @@ class BigtableScanDatasetOp : public DatasetOpKernel { BigtableTableResource* table() const { return table_; } + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + private: class Iterator : public BigtableReaderDatasetIterator { public: diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 401bec84a20..d9e7a0f4660 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -34,7 +34,9 @@ namespace tensorflow { +using boosted_trees::learner::LearnerConfig; using boosted_trees::learner::LearnerConfig_MultiClassStrategy; +using boosted_trees::learner::ObliviousSplitInfo; using boosted_trees::learner::SplitInfo; using boosted_trees::learner::stochastic::GradientStats; using boosted_trees::learner::stochastic::NodeStats; @@ -158,6 +160,11 @@ class BuildDenseInequalitySplitsOp : public OpKernel { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); + const Tensor* weak_learner_type_t; + OP_REQUIRES_OK(context, + context->input("weak_learner_type", &weak_learner_type_t)); + const int32 weak_learner_type = weak_learner_type_t->scalar()(); + // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; partition_boundaries.push_back(0); @@ -188,20 +195,59 @@ class BuildDenseInequalitySplitsOp : public OpKernel { tensorflow::TTypes::Vec output_partition_ids = output_partition_ids_t->vec(); - Tensor* gains_t = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output("gains", TensorShape({num_elements}), - &gains_t)); + // For a normal tree, we output a split per partition. 
For an oblivious + // tree, we output one split for all partitions of the layer + int32 size_output = num_elements; + if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE && + num_elements > 0) { + size_output = 1; + } + Tensor* gains_t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + "gains", TensorShape({size_output}), &gains_t)); tensorflow::TTypes::Vec gains = gains_t->vec(); Tensor* output_splits_t = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - "split_infos", TensorShape({num_elements}), - &output_splits_t)); + OP_REQUIRES_OK(context, context->allocate_output("split_infos", + TensorShape({size_output}), + &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + + if (num_elements == 0) { + return; + } SplitBuilderState state(context); + switch (weak_learner_type) { + case LearnerConfig::NORMAL_DECISION_TREE: { + ComputeNormalDecisionTree( + &state, normalizer_ratio, num_elements, partition_boundaries, + bucket_boundaries, partition_ids, bucket_ids, gradients_t, + hessians_t, &output_partition_ids, &gains, &output_splits); + break; + } + case LearnerConfig::OBLIVIOUS_DECISION_TREE: { + ComputeObliviousDecisionTree( + &state, normalizer_ratio, num_elements, partition_boundaries, + bucket_boundaries, partition_ids, bucket_ids, gradients_t, + hessians_t, &output_partition_ids, &gains, &output_splits); + break; + } + } + } + + private: + void ComputeNormalDecisionTree( + SplitBuilderState* state, const float normalizer_ratio, + const int num_elements, const std::vector& partition_boundaries, + const tensorflow::TTypes::ConstVec& bucket_boundaries, + const tensorflow::TTypes::ConstVec& partition_ids, + const tensorflow::TTypes::ConstMatrix& bucket_ids, + const Tensor* gradients_t, const Tensor* hessians_t, + tensorflow::TTypes::Vec* output_partition_ids, + tensorflow::TTypes::Vec* gains, + tensorflow::TTypes::Vec* output_splits) { for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[root_idx]; @@ -213,7 +259,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel { GradientStats(*gradients_t, *hessians_t, bucket_idx); } root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats); int32 best_bucket_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -223,10 +269,10 @@ class BuildDenseInequalitySplitsOp : public OpKernel { GradientStats g(*gradients_t, *hessians_t, bucket_idx); g *= normalizer_ratio; left_gradient_stats += g; - NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); + NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats); GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); + NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -237,21 +283,125 @@ class BuildDenseInequalitySplitsOp : public OpKernel { SplitInfo split_info; auto* dense_split = split_info.mutable_split_node()->mutable_dense_float_binary_split(); - dense_split->set_feature_column(state.feature_column_group_id()); + dense_split->set_feature_column(state->feature_column_group_id()); dense_split->set_threshold( 
bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - state.FillLeaf(best_left_node_stats, left_child); - state.FillLeaf(best_right_node_stats, right_child); - split_info.SerializeToString(&output_splits(root_idx)); - gains(root_idx) = - best_gain - root_stats.gain - state.tree_complexity_regularization(); - output_partition_ids(root_idx) = partition_ids(start_index); + state->FillLeaf(best_left_node_stats, left_child); + state->FillLeaf(best_right_node_stats, right_child); + split_info.SerializeToString(&(*output_splits)(root_idx)); + (*gains)(root_idx) = + best_gain - root_stats.gain - state->tree_complexity_regularization(); + (*output_partition_ids)(root_idx) = partition_ids(start_index); } } + void ComputeObliviousDecisionTree( + SplitBuilderState* state, const float normalizer_ratio, + const int num_elements, const std::vector& partition_boundaries, + const tensorflow::TTypes::ConstVec& bucket_boundaries, + const tensorflow::TTypes::ConstVec& partition_ids, + const tensorflow::TTypes::ConstMatrix& bucket_ids, + const Tensor* gradients_t, const Tensor* hessians_t, + tensorflow::TTypes::Vec* output_partition_ids, + tensorflow::TTypes::Vec* gains, + tensorflow::TTypes::Vec* output_splits) { + // Holds the root stats per each node to be split. + std::vector current_layer_stats; + current_layer_stats.reserve(num_elements); + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + const int start_index = partition_boundaries[root_idx]; + const int end_index = partition_boundaries[root_idx + 1]; + GradientStats root_gradient_stats; + for (int64 bucket_idx = start_index; bucket_idx < end_index; + ++bucket_idx) { + root_gradient_stats += + GradientStats(*gradients_t, *hessians_t, bucket_idx); + } + root_gradient_stats *= normalizer_ratio; + current_layer_stats.push_back(root_gradient_stats); + } + + float best_gain = std::numeric_limits::lowest(); + int64 best_bucket_idx = 0; + std::vector best_right_node_stats(num_elements, NodeStats(0)); + std::vector best_left_node_stats(num_elements, NodeStats(0)); + std::vector current_left_node_stats(num_elements, NodeStats(0)); + std::vector current_right_node_stats(num_elements, NodeStats(0)); + int64 current_bucket_id = 0; + int64 last_bucket_id = -1; + // Indexes offsets for each of the partitions that can be used to access + // gradients of a partition for a current bucket we consider. + std::vector current_layer_offsets(num_elements, 0); + std::vector left_gradient_stats(num_elements); + // The idea is to try every bucket id in increasing order. In each iteration + // we calculate the gain of the layer using the current bucket id as split + // value, and we also obtain the following bucket id to try. 
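+    // For example, if partition 0 has stats at bucket ids {0, 2} and
+    // partition 1 has stats at bucket ids {1, 2}, the sweep visits bucket
+    // ids 0, 1 and 2. At each step, every partition whose next unvisited
+    // entry matches current_bucket_id folds that entry into its left-hand
+    // stats and advances its offset, so all partitions of the layer are
+    // always evaluated against the same candidate threshold.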
+ while (current_bucket_id > last_bucket_id) { + last_bucket_id = current_bucket_id; + int64 next_bucket_id = -1; + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + int idx = + current_layer_offsets[root_idx] + partition_boundaries[root_idx]; + const int end_index = partition_boundaries[root_idx + 1]; + if (idx < end_index && bucket_ids(idx, 0) == current_bucket_id) { + GradientStats g(*gradients_t, *hessians_t, idx); + g *= normalizer_ratio; + left_gradient_stats[root_idx] += g; + current_layer_offsets[root_idx]++; + idx++; + } + if (idx < end_index && + (bucket_ids(idx, 0) < next_bucket_id || next_bucket_id == -1)) { + next_bucket_id = bucket_ids(idx, 0); + } + } + float gain_of_split = 0.0; + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + GradientStats right_gradient_stats = + current_layer_stats[root_idx] - left_gradient_stats[root_idx]; + NodeStats left_stat = + state->ComputeNodeStats(left_gradient_stats[root_idx]); + NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats); + gain_of_split += left_stat.gain + right_stat.gain; + current_left_node_stats[root_idx] = left_stat; + current_right_node_stats[root_idx] = right_stat; + } + if (gain_of_split > best_gain) { + best_gain = gain_of_split; + best_left_node_stats = current_left_node_stats; + best_right_node_stats = current_right_node_stats; + } + current_bucket_id = next_bucket_id; + } + + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain; + } + best_gain -= num_elements * state->tree_complexity_regularization(); + + ObliviousSplitInfo oblivious_split_info; + auto* oblivious_dense_split = oblivious_split_info.mutable_split_node() + ->mutable_dense_float_binary_split(); + oblivious_dense_split->set_feature_column(state->feature_column_group_id()); + oblivious_dense_split->set_threshold( + bucket_boundaries(bucket_ids(best_bucket_idx, 0))); + (*gains)(0) = best_gain; + + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + auto* left_children = oblivious_split_info.add_children_leaves(); + auto* right_children = oblivious_split_info.add_children_leaves(); + + state->FillLeaf(best_left_node_stats[root_idx], left_children); + state->FillLeaf(best_right_node_stats[root_idx], right_children); + + const int start_index = partition_boundaries[root_idx]; + (*output_partition_ids)(root_idx) = partition_ids(start_index); + } + oblivious_split_info.SerializeToString(&(*output_splits)(0)); + } }; REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU), BuildDenseInequalitySplitsOp); diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 2559fe9913f..f45010ec26e 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -64,6 +64,7 @@ from __future__ import print_function import re from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler +from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops from tensorflow.contrib.boosted_trees.python.ops import quantile_ops @@ -171,6 +172,7 @@ class DenseSplitHandler(InequalitySplitHandler): multiclass_strategy, init_stamp_token=0, 
loss_uses_sum_reduction=False, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE, name=None): """Initialize the internal state for this split handler. @@ -192,6 +194,7 @@ class DenseSplitHandler(InequalitySplitHandler): stamped objects. loss_uses_sum_reduction: A scalar boolean tensor that specifies whether SUM or MEAN reduction was used for the loss. + weak_learner_type: Specifies the type of weak learner to use. name: An optional handler name. """ super(DenseSplitHandler, self).__init__( @@ -209,6 +212,7 @@ class DenseSplitHandler(InequalitySplitHandler): multiclass_strategy=multiclass_strategy, loss_uses_sum_reduction=loss_uses_sum_reduction) self._dense_float_column = dense_float_column + self._weak_learner_type = weak_learner_type # Register dense_make_stats_update function as an Op to the graph. g = ops.get_default_graph() dense_make_stats_update.add_to_graph(g) @@ -269,16 +273,17 @@ class DenseSplitHandler(InequalitySplitHandler): next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, - self._min_node_weight, self._loss_uses_sum_reduction)) - + self._min_node_weight, self._loss_uses_sum_reduction, + self._weak_learner_type)) return are_splits_ready, partition_ids, gains, split_infos -def _make_dense_split( - quantile_accumulator_handle, stats_accumulator_handle, stamp_token, - next_stamp_token, multiclass_strategy, class_id, feature_column_id, - l1_regularization, l2_regularization, tree_complexity_regularization, - min_node_weight, is_multi_dimentional, loss_uses_sum_reduction): +def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, + class_id, feature_column_id, l1_regularization, + l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional, + loss_uses_sum_reduction, weak_learner_type): """Function that builds splits for a dense feature column.""" # Get the bucket boundaries are_splits_ready, buckets = ( @@ -327,7 +332,8 @@ def _make_dense_split( l2_regularization=l2_regularization, tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - multiclass_strategy=multiclass_strategy)) + multiclass_strategy=multiclass_strategy, + weak_learner_type=weak_learner_type)) return are_splits_ready, partition_ids, gains, split_infos @@ -507,7 +513,40 @@ def _make_sparse_split( return are_splits_ready, partition_ids, gains, split_infos -def _specialize_make_split(func, is_multi_dimentional): +def _specialize_make_split_dense(func, is_multi_dimentional): + """Builds a specialized version of the function.""" + + @function.Defun( + dtypes.resource, + dtypes.resource, + dtypes.int64, + dtypes.int64, + dtypes.int32, + dtypes.int32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.bool, + dtypes.int32, + noinline=True) + def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, loss_uses_sum_reduction, weak_learner_type): + """Function that builds splits for a sparse feature column.""" + return func(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, class_id, + feature_column_id, l1_regularization, l2_regularization, + 
tree_complexity_regularization, min_node_weight, + is_multi_dimentional, loss_uses_sum_reduction, + weak_learner_type) + + return f + + +def _specialize_make_split_sparse(func, is_multi_dimentional): """Builds a specialized version of the function.""" @function.Defun( @@ -537,15 +576,17 @@ def _specialize_make_split(func, is_multi_dimentional): return f -make_dense_split_scalar = _specialize_make_split(_make_dense_split, - is_multi_dimentional=False) -make_dense_split_tensor = _specialize_make_split(_make_dense_split, - is_multi_dimentional=True) -make_sparse_split_scalar = _specialize_make_split(_make_sparse_split, - is_multi_dimentional=False) -make_sparse_split_tensor = _specialize_make_split(_make_sparse_split, - is_multi_dimentional=True) +make_dense_split_scalar = _specialize_make_split_dense( + _make_dense_split, is_multi_dimentional=False) + +make_dense_split_tensor = _specialize_make_split_dense( + _make_dense_split, is_multi_dimentional=True) + +make_sparse_split_scalar = _specialize_make_split_sparse( + _make_sparse_split, is_multi_dimentional=False) +make_sparse_split_tensor = _specialize_make_split_sparse( + _make_sparse_split, is_multi_dimentional=True) @function.Defun( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 5d82c4cae5d..6572f2f414b 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -182,6 +182,133 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) + def testObliviousFeatureSplitGeneration(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Dense Quantile | + # i0 | (0.2, 0.12) | 0 | 2 | + # i1 | (-0.5, 0.07) | 0 | 2 | + # i2 | (1.2, 0.2) | 0 | 0 | + # i3 | (4.0, 0.13) | 1 | 1 | + dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52]) + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + class_id = -1 + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + split_handler = ordinal_split_handler.DenseSplitHandler( + l1_regularization=0.1, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., + epsilon=0.001, + num_quantiles=10, + feature_column_group_id=0, + dense_float_column=dense_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + 
hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([0, 1], partitions) + + oblivious_split_info = split_info_pb2.ObliviousSplitInfo() + oblivious_split_info.ParseFromString(splits[0]) + split_node = oblivious_split_info.split_node.dense_float_binary_split + + self.assertAllClose(0.3, split_node.threshold, 0.00001) + self.assertEqual(0, split_node.feature_column) + + # Check the split on partition 0. + # -(1.2 - 0.1) / (0.2 + 1) + expected_left_weight_0 = -0.9166666666666666 + + # expected_left_weight_0 * -(1.2 - 0.1) + expected_left_gain_0 = 1.008333333333333 + + # (-0.5 + 0.2 + 0.1) / (0.19 + 1) + expected_right_weight_0 = 0.1680672 + + # expected_right_weight_0 * -(-0.5 + 0.2 + 0.1)) + expected_right_gain_0 = 0.033613445378151252 + + # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) + expected_bias_gain_0 = 0.46043165467625896 + + left_child = oblivious_split_info.children_leaves[0].vector + right_child = oblivious_split_info.children_leaves[1].vector + + self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001) + + # Check the split on partition 1. + expected_left_weight_1 = 0 + expected_left_gain_1 = 0 + # -(4 - 0.1) / (0.13 + 1) + expected_right_weight_1 = -3.4513274336283186 + # expected_right_weight_1 * -(4 - 0.1) + expected_right_gain_1 = 13.460176991150442 + # (-4 + 0.1) ** 2 / (0.13 + 1) + expected_bias_gain_1 = 13.460176991150442 + + left_child = oblivious_split_info.children_leaves[2].vector + right_child = oblivious_split_info.children_leaves[3].vector + + self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001) + + # The layer gain is the sum of the gains of each partition + layer_gain = ( + expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + ( + expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1) + self.assertAllClose(layer_gain, gains[0], 0.00001) + def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): with self.test_session() as sess: # The data looks like the following: diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index ca5c7f3d8c7..9b68a9de96e 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -36,6 +36,7 @@ REGISTER_OP("BuildDenseInequalitySplits") .Input("tree_complexity_regularization: float") .Input("min_node_weight: float") .Input("multiclass_strategy: int32") + .Input("weak_learner_type: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -84,6 +85,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child. be considered. 
multiclass_strategy: A scalar, specifying the multiclass handling strategy. See LearnerConfig.MultiClassStrategy for valid values. +weak_learner_type: A scalar, specifying the weak learner type to use. + See LearnerConfig.WeakLearnerType for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto index d84ba7438e7..c49cb48cdea 100644 --- a/tensorflow/contrib/boosted_trees/proto/learner.proto +++ b/tensorflow/contrib/boosted_trees/proto/learner.proto @@ -108,6 +108,11 @@ message LearnerConfig { DIAGONAL_HESSIAN = 3; } + enum WeakLearnerType { + NORMAL_DECISION_TREE = 0; + OBLIVIOUS_DECISION_TREE = 1; + } + // Number of classes. uint32 num_classes = 1; @@ -141,4 +146,7 @@ message LearnerConfig { // If you want to average the ensembles (for regularization), provide the // config below. AveragingConfig averaging_config = 11; + + // By default we use NORMAL_DECISION_TREE as weak learner. + WeakLearnerType weak_learner_type = 12; } diff --git a/tensorflow/contrib/boosted_trees/proto/split_info.proto b/tensorflow/contrib/boosted_trees/proto/split_info.proto index a300c24c8ec..850340f5c20 100644 --- a/tensorflow/contrib/boosted_trees/proto/split_info.proto +++ b/tensorflow/contrib/boosted_trees/proto/split_info.proto @@ -17,3 +17,10 @@ message SplitInfo { // Right Leaf node. tensorflow.boosted_trees.trees.Leaf right_child = 3; } + +message ObliviousSplitInfo { + // The split node with the feature_column and threshold defined. + tensorflow.boosted_trees.trees.TreeNode split_node = 1; + // The new leaves of the tree. + repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2; +} diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index 5cd37ec67ec..25895047627 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -59,7 +59,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): min_node_weight=0, class_id=-1, feature_column_group_id=0, - multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) self.assertAllEqual([0, 1], partitions) @@ -132,7 +133,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): min_node_weight=0, class_id=-1, feature_column_group_id=0, - multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN)) + multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) self.assertAllEqual([0, 1], partitions) @@ -171,7 +173,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): min_node_weight=0, class_id=-1, feature_column_group_id=0, - multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) # .assertEmpty doesn't exist on ubuntu-contrib 
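The new `WeakLearnerType` enum and `weak_learner_type` field (tag 12) are set on a `LearnerConfig` like any other proto field; the GBDT training code further below turns the value into a constant tensor and feeds it to the new `weak_learner_type` op input. A minimal sketch, assuming the generated `learner_pb2` module for the proto above:

```python
from tensorflow.contrib.boosted_trees.proto import learner_pb2

config = learner_pb2.LearnerConfig()
config.num_classes = 2
# NORMAL_DECISION_TREE (0) is the default; opt into oblivious trees explicitly.
config.weak_learner_type = learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE

print(learner_pb2.LearnerConfig.WeakLearnerType.Name(config.weak_learner_type))
# OBLIVIOUS_DECISION_TREE
```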
self.assertEqual(0, len(partitions)) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index d0d1249bd6a..20ff48c3602 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -672,6 +672,8 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config.constraints.min_node_weight, dtypes.float32) loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction) + weak_learner_type = constant_op.constant( + self._learner_config.weak_learner_type) epsilon = 0.01 num_quantiles = 100 strategy_tensor = constant_op.constant(strategy) @@ -696,6 +698,7 @@ class GradientBoostedDecisionTreeModel(object): multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token, loss_uses_sum_reduction=loss_uses_sum_reduction, + weak_learner_type=weak_learner_type, )) fc_name_idx += 1 diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 2fbaa31d5e1..e92f0bb841a 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -31,6 +31,9 @@ Checkpointable data structures: @@List @@Mapping @@UniqueNameTracker + +Checkpoint management: +@@CheckpointManager """ from __future__ import absolute_import @@ -41,6 +44,7 @@ from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.python.training.checkpoint_management import CheckpointManager from tensorflow.python.training.checkpointable.base import CheckpointableBase from tensorflow.python.training.checkpointable.data_structures import List from tensorflow.python.training.checkpointable.data_structures import Mapping diff --git a/tensorflow/contrib/constrained_optimization/python/candidates.py b/tensorflow/contrib/constrained_optimization/python/candidates.py index ac86a6741be..66d7ebed74d 100644 --- a/tensorflow/contrib/constrained_optimization/python/candidates.py +++ b/tensorflow/contrib/constrained_optimization/python/candidates.py @@ -204,7 +204,7 @@ def find_best_candidate_distribution(objective_vector, assert best_pp is not None # Throughout this loop, a maximum_violation of "lower" is not achievable, - # but a maximum_violation of "upper" is achiveable. + # but a maximum_violation of "upper" is achievable. 
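For the `CheckpointManager` symbol newly re-exported from `tf.contrib.checkpoint`: a hedged usage sketch. The constructor arguments (`checkpoint`, `directory`, `max_to_keep`) and the `checkpoints`/`latest_checkpoint` properties are assumed here to match the later public `tf.train.CheckpointManager`, so treat this as illustrative rather than authoritative for this exact revision:

```python
import tensorflow as tf
tf.enable_eager_execution()

step = tf.Variable(0, name="step")
checkpoint = tf.train.Checkpoint(step=step)
manager = tf.contrib.checkpoint.CheckpointManager(
    checkpoint, directory="/tmp/ckpt_demo", max_to_keep=3)

for _ in range(5):
  step.assign_add(1)
  manager.save()

# Only the three most recent checkpoints should remain on disk.
print(manager.checkpoints)
print(manager.latest_checkpoint)
```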
while True: middle = 0.5 * (lower + upper) if (middle - lower <= epsilon) or (upper - middle <= epsilon): diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py index 70813fb2179..41258edd908 100644 --- a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py +++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py @@ -72,7 +72,8 @@ class ConstrainedMinimizationProblem(object): else: proxy_constraints_shape = self.proxy_constraints.get_shape() - if (constraints_shape is None or proxy_constraints_shape is None or + if (constraints_shape.ndims is None or + proxy_constraints_shape.ndims is None or any([ii is None for ii in constraints_shape.as_list()]) or any([ii is None for ii in proxy_constraints_shape.as_list()])): raise ValueError( @@ -121,3 +122,19 @@ class ConstrainedMinimizationProblem(object): A tensor of proxy constraint functions. """ return None + + # This is a property, instead of an abstract property, since it doesn't need + # to be overridden: if pre_train_ops returns None, then there are no ops to + # run before train_op. + @property + def pre_train_ops(self): + """Returns a list of `Operation`s to run before the train_op. + + When a `ConstrainedOptimizer` creates a train_op (in `minimize` + `minimize_unconstrained`, or `minimize_constrained`), it will include these + ops before the main training step. + + Returns: + A list of `Operation`s. + """ + return None diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py index 80555453661..0b79bdf7c05 100644 --- a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py @@ -55,20 +55,21 @@ class ConstrainedOptimizer(object): """Returns the `tf.train.Optimizer` used for optimization.""" return self._optimizer - def minimize_unconstrained(self, - minimization_problem, - global_step=None, - var_list=None, - gate_gradients=train_optimizer.Optimizer.GATE_OP, - aggregation_method=None, - colocate_gradients_with_ops=False, - name=None, - grad_loss=None): - """Returns an `Op` for minimizing the unconstrained problem. + @abc.abstractmethod + def _minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Version of `minimize_constrained` to be overridden by subclasses. - Unlike `minimize_constrained`, this function ignores the `constraints` (and - `proxy_constraints`) portion of the minimization problem entirely, and only - minimizes `objective`. + Implementations of this method should ignore the `pre_train_ops` property of + the `minimization_problem`. The public `minimize_constrained` method will + take care of executing these before the returned train_op. Args: minimization_problem: ConstrainedMinimizationProblem, the problem to @@ -83,19 +84,10 @@ class ConstrainedOptimizer(object): grad_loss: as in `tf.train.Optimizer`'s `minimize` method. Returns: - TensorFlow Op. + `Operation`, the train_op. 
""" - return self.optimizer.minimize( - minimization_problem.objective, - global_step=global_step, - var_list=var_list, - gate_gradients=gate_gradients, - aggregation_method=aggregation_method, - colocate_gradients_with_ops=colocate_gradients_with_ops, - name=name, - grad_loss=grad_loss) + pass - @abc.abstractmethod def minimize_constrained(self, minimization_problem, global_step=None, @@ -105,7 +97,7 @@ class ConstrainedOptimizer(object): colocate_gradients_with_ops=False, name=None, grad_loss=None): - """Returns an `Op` for minimizing the constrained problem. + """Returns an `Operation` for minimizing the constrained problem. Unlike `minimize_unconstrained`, this function attempts to find a solution that minimizes the `objective` portion of the minimization problem while @@ -124,9 +116,83 @@ class ConstrainedOptimizer(object): grad_loss: as in `tf.train.Optimizer`'s `minimize` method. Returns: - TensorFlow Op. + `Operation`, the train_op. """ - pass + + def train_op_callback(): + return self._minimize_constrained( + minimization_problem, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + # If we have pre_train_ops, use tf.control_dependencies() to ensure that + # they execute before the train_op. + pre_train_ops = minimization_problem.pre_train_ops + if pre_train_ops: + with ops.control_dependencies(pre_train_ops): + train_op = train_op_callback() + else: + train_op = train_op_callback() + + return train_op + + def minimize_unconstrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Operation` for minimizing the unconstrained problem. + + Unlike `minimize_constrained`, this function ignores the `constraints` (and + `proxy_constraints`) portion of the minimization problem entirely, and only + minimizes `objective`. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + `Operation`, the train_op. + """ + + def train_op_callback(): + return self.optimizer.minimize( + minimization_problem.objective, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + # If we have pre_train_ops, use tf.control_dependencies() to ensure that + # they execute before the train_op. 
+ pre_train_ops = minimization_problem.pre_train_ops + if pre_train_ops: + with ops.control_dependencies(pre_train_ops): + train_op = train_op_callback() + else: + train_op = train_op_callback() + + return train_op def minimize(self, minimization_problem, @@ -138,7 +204,7 @@ class ConstrainedOptimizer(object): colocate_gradients_with_ops=False, name=None, grad_loss=None): - """Returns an `Op` for minimizing the constrained problem. + """Returns an `Operation` for minimizing the constrained problem. This method combines the functionality of `minimize_unconstrained` and `minimize_constrained`. If global_step < unconstrained_steps, it will @@ -164,14 +230,14 @@ class ConstrainedOptimizer(object): grad_loss: as in `tf.train.Optimizer`'s `minimize` method. Returns: - TensorFlow Op. + `Operation`, the train_op. Raises: ValueError: If unconstrained_steps is provided, but global_step is not. """ def unconstrained_fn(): - """Returns an `Op` for minimizing the unconstrained problem.""" + """Returns an `Operation` for minimizing the unconstrained problem.""" return self.minimize_unconstrained( minimization_problem=minimization_problem, global_step=global_step, @@ -183,7 +249,7 @@ class ConstrainedOptimizer(object): grad_loss=grad_loss) def constrained_fn(): - """Returns an `Op` for minimizing the constrained problem.""" + """Returns an `Operation` for minimizing the constrained problem.""" return self.minimize_constrained( minimization_problem=minimization_problem, global_step=global_step, diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py index 01c6e4f08af..d1af15f7e42 100644 --- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py @@ -70,11 +70,13 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius): region w.r.t. the Euclidean norm. Raises: - ValueError: if the `multipliers` tensor does not have a fully-known shape, - or is not one-dimensional. + ValueError: if the `multipliers` tensor is not floating-point, does not have + a fully-known shape, or is not one-dimensional. """ + if not multipliers.dtype.is_floating: + raise ValueError("multipliers must have a floating-point dtype") multipliers_shape = multipliers.get_shape() - if multipliers_shape is None: + if multipliers_shape.ndims is None: raise ValueError("multipliers must have known shape") if multipliers_shape.ndims != 1: raise ValueError( @@ -101,12 +103,12 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius): (radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum( 1.0, standard_ops.reduce_sum(inactive))) multipliers += scale * inactive - new_inactive = standard_ops.to_float(multipliers > 0) + new_inactive = standard_ops.cast(multipliers > 0, multipliers.dtype) multipliers *= new_inactive return (iteration, multipliers, new_inactive, inactive) iteration = standard_ops.constant(0) - inactive = standard_ops.ones_like(multipliers) + inactive = standard_ops.ones_like(multipliers, dtype=multipliers.dtype) # We actually want a do-while loop, so we explicitly call while_loop_body() # once before tf.while_loop(). 
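For intuition about `_project_multipliers_wrt_euclidean_norm` (and why `inactive` must now carry the multipliers' dtype): the loop repeatedly removes mass uniformly from the still-positive coordinates and clips at zero until `sum(v) <= radius` holds. Below is a NumPy sketch in the same spirit; it mirrors the structure of the TF while-loop but is not a line-for-line transcription:

```python
import numpy as np

def project_multipliers(multipliers, radius):
  """Project onto {v : v >= 0, sum(v) <= radius} by repeated shrink-and-clip."""
  multipliers = np.asarray(multipliers, dtype=np.float64)
  inactive = np.ones_like(multipliers)        # 1.0 where a coordinate may still move
  old_inactive = np.zeros_like(multipliers)
  while not np.array_equal(inactive, old_inactive):   # do-while, as in the TF code
    old_inactive = inactive
    # Never add mass: only shrink when the sum exceeds the radius.
    scale = min(0.0, (radius - multipliers.sum()) / max(1.0, inactive.sum()))
    multipliers = multipliers + scale * inactive
    inactive = (multipliers > 0).astype(multipliers.dtype)   # cf. the cast() above
    multipliers = multipliers * inactive
  return multipliers

print(project_multipliers([0.4, -0.2, 1.5], radius=1.0))   # -> [0. 0. 1.]
```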
@@ -189,16 +191,16 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): def _projection_op(self, state, name=None): pass - def minimize_constrained(self, - minimization_problem, - global_step=None, - var_list=None, - gate_gradients=train_optimizer.Optimizer.GATE_OP, - aggregation_method=None, - colocate_gradients_with_ops=False, - name=None, - grad_loss=None): - """Returns an `Op` for minimizing the constrained problem. + def _minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Operation` for minimizing the constrained problem. The `optimizer` constructor parameter will be used to update the model parameters, while the Lagrange multipliers will be updated using @@ -216,8 +218,11 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): name: as in `tf.train.Optimizer`'s `minimize` method. grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + Raises: + ValueError: If the minimization_problem tensors have different dtypes. + Returns: - TensorFlow Op. + `Operation`, the train_op. """ objective = minimization_problem.objective @@ -225,6 +230,14 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): proxy_constraints = minimization_problem.proxy_constraints if proxy_constraints is None: proxy_constraints = constraints + + # Make sure that the objective, constraints and proxy constraints all have + # the same dtype. + if (objective.dtype.base_dtype != constraints.dtype.base_dtype or + objective.dtype.base_dtype != proxy_constraints.dtype.base_dtype): + raise ValueError("objective, constraints and proxy_constraints must " + "have the same dtype") + # Flatten both constraints tensors to 1d. num_constraints = minimization_problem.num_constraints constraints = standard_ops.reshape(constraints, shape=(num_constraints,)) @@ -241,8 +254,10 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): multipliers = self._lagrange_multipliers(state) loss = ( - objective + standard_ops.tensordot(multipliers, proxy_constraints, 1)) - multipliers_gradient = constraints + objective + standard_ops.tensordot( + standard_ops.cast(multipliers, proxy_constraints.dtype), + proxy_constraints, 1)) + multipliers_gradient = standard_ops.cast(constraints, multipliers.dtype) update_ops = [] if self.constraint_optimizer is None: @@ -356,6 +371,8 @@ class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer): # For an AdditiveExternalRegretOptimizer, the internal state is simply a # tensor of Lagrange multipliers with shape (m,), where m is the number of # constraints. + # + # FUTURE WORK: make the dtype a parameter. return standard_ops.zeros((num_constraints,), dtype=dtypes.float32) def _lagrange_multipliers(self, state): diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py index ff846b191a3..2c673d93471 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -79,9 +79,11 @@ def _maximal_eigenvector_power_method(matrix, The maximal right-eigenvector of `matrix`. Raises: - ValueError: If the epsilon or maximum_iterations parameters violate their - bounds. 
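Several changes in these two optimizer files swap `to_float` and plain `ones_like` for dtype-preserving `cast` and `ones_like(..., dtype=...)` calls. The difference only shows up for non-`float32` problems; a small TF 1.x sketch of why the cast matters:

```python
import tensorflow as tf

multipliers = tf.constant([-1.0, 0.5, 2.0], dtype=tf.float64)

mask_f32 = tf.to_float(multipliers > 0)                  # always float32
mask_same = tf.cast(multipliers > 0, multipliers.dtype)  # float64, matches the input

print(mask_f32.dtype, mask_same.dtype)
# multipliers * mask_f32 would raise a dtype mismatch (float64 vs float32);
# multipliers * mask_same works for any floating-point multipliers dtype.

with tf.Session() as sess:
  print(sess.run(mask_same))   # [0. 1. 1.]
```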
+ ValueError: If the `matrix` tensor is not floating-point, or if the + `epsilon` or `maximum_iterations` parameters violate their bounds. """ + if not matrix.dtype.is_floating: + raise ValueError("multipliers must have a floating-point dtype") if epsilon <= 0.0: raise ValueError("epsilon must be strictly positive") if maximum_iterations <= 0: @@ -139,11 +141,13 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix): (i.e. the Frobenius norm). Raises: - ValueError: if the `matrix` tensor does not have a fully-known shape, or is - not two-dimensional and square. + ValueError: if the `matrix` tensor is not floating-point, does not have a + fully-known shape, or is not two-dimensional and square. """ + if not matrix.dtype.is_floating: + raise ValueError("multipliers must have a floating-point dtype") matrix_shape = matrix.get_shape() - if matrix_shape is None: + if matrix_shape.ndims is None: raise ValueError("matrix must have known shape") if matrix_shape.ndims != 2: raise ValueError( @@ -172,12 +176,12 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix): matrix, axis=0, keepdims=True)) / standard_ops.maximum( 1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True)) matrix += scale * inactive - new_inactive = standard_ops.to_float(matrix > 0) + new_inactive = standard_ops.cast(matrix > 0, matrix.dtype) matrix *= new_inactive return (iteration, matrix, new_inactive, inactive) iteration = standard_ops.constant(0) - inactive = standard_ops.ones_like(matrix) + inactive = standard_ops.ones_like(matrix, dtype=matrix.dtype) # We actually want a do-while loop, so we explicitly call while_loop_body() # once before tf.while_loop(). @@ -218,7 +222,7 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): """Base class representing a `_SwapRegretOptimizer`. This class contains most of the logic for performing constrained optimization, - minimizing external regret for the constraints player. What it *doesn't* do is + minimizing swap regret for the constraints player. What it *doesn't* do is keep track of the internal state (the stochastic matrix). Instead, the state is accessed via the _initial_state(), _stochastic_matrix(), _constraint_grad_and_var() and _projection_op() methods. @@ -291,16 +295,16 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): def _projection_op(self, state, name=None): pass - def minimize_constrained(self, - minimization_problem, - global_step=None, - var_list=None, - gate_gradients=train_optimizer.Optimizer.GATE_OP, - aggregation_method=None, - colocate_gradients_with_ops=False, - name=None, - grad_loss=None): - """Returns an `Op` for minimizing the constrained problem. + def _minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Operation` for minimizing the constrained problem. The `optimizer` constructor parameter will be used to update the model parameters, while the constraint/objective weight matrix (the analogue of @@ -320,8 +324,11 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): name: as in `tf.train.Optimizer`'s `minimize` method. grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + Raises: + ValueError: If the minimization_problem tensors have different dtypes. + Returns: - TensorFlow Op. + `Operation`, the train_op. 
""" objective = minimization_problem.objective @@ -329,6 +336,14 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): proxy_constraints = minimization_problem.proxy_constraints if proxy_constraints is None: proxy_constraints = constraints + + # Make sure that the objective, constraints and proxy constraints all have + # the same dtype. + if (objective.dtype.base_dtype != constraints.dtype.base_dtype or + objective.dtype.base_dtype != proxy_constraints.dtype.base_dtype): + raise ValueError("objective, constraints and proxy_constraints must " + "have the same dtype") + # Flatten both constraints tensors to 1d. num_constraints = minimization_problem.num_constraints constraints = standard_ops.reshape(constraints, shape=(num_constraints,)) @@ -344,15 +359,18 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): name="swap_regret_optimizer_state") zero_and_constraints = standard_ops.concat( - (standard_ops.zeros((1,)), constraints), axis=0) + (standard_ops.zeros((1,), dtype=constraints.dtype), constraints), + axis=0) objective_and_proxy_constraints = standard_ops.concat( (standard_ops.expand_dims(objective, 0), proxy_constraints), axis=0) distribution = self._distribution(state) - loss = standard_ops.tensordot(distribution, objective_and_proxy_constraints, - 1) + loss = standard_ops.tensordot( + standard_ops.cast(distribution, objective_and_proxy_constraints.dtype), + objective_and_proxy_constraints, 1) matrix_gradient = standard_ops.matmul( - standard_ops.expand_dims(zero_and_constraints, 1), + standard_ops.expand_dims( + standard_ops.cast(zero_and_constraints, distribution.dtype), 1), standard_ops.expand_dims(distribution, 0)) update_ops = [] @@ -555,6 +573,7 @@ class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer): log_initial_one = math.log(1.0 - (self._initial_multiplier_radius * (dimension - 1) / (dimension))) log_initial_zero = math.log(self._initial_multiplier_radius / dimension) + # FUTURE WORK: make the dtype a parameter. 
return standard_ops.concat( (standard_ops.constant( log_initial_one, dtype=dtypes.float32, shape=(1, dimension)), diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc index bff63012501..e36c9c06342 100644 --- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc @@ -42,13 +42,13 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, const std::vector& transformations, const DataTypeVector& output_types, const std::vector& output_shapes) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), input_(input), transformations_(transformations), output_types_(output_types), diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index 51e1b9aa656..d242cfdf491 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -131,7 +131,7 @@ class CSVDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, std::vector filenames, bool header, string compression_type, io::ZlibCompressionOptions options, @@ -139,7 +139,7 @@ class CSVDatasetOp : public DatasetOpKernel { const std::vector& output_shapes, std::vector record_defaults, std::vector select_cols, bool use_quote_delim, char delim, string na_value) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), filenames_(std::move(filenames)), header_(header), out_type_(output_types), diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc index b9306f611b8..ccf7ec1f842 100644 --- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -63,11 +63,11 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* selector_input, std::vector data_inputs) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), selector_input_(selector_input), data_inputs_(std::move(data_inputs)) { selector_input_->Ref(); diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index d77beb8e105..db24e608463 100644 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -35,10 +35,10 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input) - : GraphDatasetBase(ctx), input_(input) { + : DatasetBase(DatasetContext(ctx)), input_(input) { input_->Ref(); } diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index 13bcd77b4af..74df1e42a8f 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -929,10 +929,9 @@ class 
MultiDeviceIteratorInitOp : public OpKernel { LookupResource(ctx, HandleFromInput(ctx, 1), &resource)); core::ScopedUnref unref(resource); - IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); std::unique_ptr iterator; - OP_REQUIRES_OK(ctx, - dataset->MakeIterator(&iter_ctx, "Iterator", &iterator)); + OP_REQUIRES_OK(ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", + &iterator)); int64 incarnation_id; OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size, &incarnation_id)); diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index 4dc69dc2efa..ab584504a05 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -130,11 +130,13 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, ThreadPoolResource* threadpool) - : GraphDatasetBase(ctx), input_(input), threadpool_(threadpool) { + : DatasetBase(DatasetContext(ctx)), + input_(input), + threadpool_(threadpool) { input_->Ref(); threadpool_->Ref(); } @@ -165,9 +167,8 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - return errors::Unimplemented( - "Cannot currently serialize the thread pool for a " - "ThreadPoolDataset."); + return errors::Unimplemented("%s does not support serialization", + DebugString()); } private: diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc index f6bfc982e93..6fbf5d2ebb5 100644 --- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -47,10 +47,10 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input) - : GraphDatasetBase(ctx), input_(input) { + : DatasetBase(DatasetContext(ctx)), input_(input) { input_->Ref(); } diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 9123ca749b6..2c93ce92ceb 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -22,13 +22,14 @@ from __future__ import print_function from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy from tensorflow.contrib.distribute.python.cross_tower_ops import * from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy -from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy from tensorflow.contrib.distribute.python.monitor import Monitor +from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.python.training.distribute import * +from tensorflow.python.training.distribution_strategy_context import * from 
tensorflow.python.util.all_util import remove_undocumented @@ -55,6 +56,7 @@ _allowed_symbols = [ 'get_tower_context', 'has_distribution_strategy', 'require_tower_context', + 'UpdateContext', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index a1efbcaf9ac..aeec9c44d72 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -56,7 +56,7 @@ from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.training import adam -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent from tensorflow.python.util import tf_inspect @@ -320,7 +320,7 @@ class NamedDistribution(object): # pylint: disable=g-long-lambda default_strategy = NamedDistribution( "Default", - lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access + distribution_strategy_context._get_default_distribution_strategy, # pylint: disable=protected-access required_gpus=None) one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index 3e00cf4332d..cc626c33bf8 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -29,6 +29,7 @@ from tensorflow.contrib.optimizer_v2 import adagrad from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test from tensorflow.python.estimator import run_config +from tensorflow.python.estimator import training from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export @@ -63,8 +64,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus - ])) - def test_complete_flow_with_mode(self, distribution): + ], + use_train_and_evaluate=[True, False])) + def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): label_dimension = 2 input_dimension = label_dimension batch_size = 10 @@ -103,9 +105,15 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, train_distribute=distribution, eval_distribute=distribution)) num_steps = 10 - estimator.train(train_input_fn, steps=num_steps) + if use_train_and_evaluate: + scores, _ = training.train_and_evaluate( + estimator, + training.TrainSpec(train_input_fn, max_steps=num_steps), + training.EvalSpec(eval_input_fn)) + else: + estimator.train(train_input_fn, steps=num_steps) + scores = estimator.evaluate(eval_input_fn) - scores = estimator.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) self.assertIn('loss', six.iterkeys(scores)) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index e064cfe37db..9a4cc0a8975 100644 --- 
a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -40,7 +40,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context GPU_TEST = "test_gpu" in sys.argv[0] @@ -164,7 +164,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # This variable should be created only once across the threads because of # special variable_creator functions used by `dist.call_for_each_tower`. v = variable_scope.variable(1.0, name="foo") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -181,7 +181,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): v = variable_scope.variable(1.0) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -201,7 +201,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): vs = [] for i in range(5): vs.append(variable_scope.variable(1.0, name="foo" + str(i))) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return vs dist = mirrored_strategy.MirroredStrategy( @@ -223,7 +223,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): vs.append(variable_scope.variable(1.0, name="foo_1/bar")) vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) vs.append(variable_scope.variable(1.0, name="foo/bar_1")) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return vs dist = mirrored_strategy.MirroredStrategy( @@ -245,7 +245,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(device_id): v = variable_scope.variable(1.0, name="foo_" + str(device_id)) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -268,7 +268,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): layer2 = core.Dense(1) layer2(features) # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) layer3 = core.Dense(1) layer3(features) return [(layer1.kernel, layer1.bias), @@ -300,7 +301,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with variable_scope.variable_scope("common"): v1 = variable_scope.variable(1.0, name="var1") # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) v2 = variable_scope.variable( 1.0, name="var2", @@ -343,7 +345,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with variable_scope.variable_scope("common"): v1 = variable_scope.get_variable("var1", [1]) # This will pause the current thread, and execute the other thread. 
- distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) v2 = variable_scope.get_variable( "var2", [1], synchronization=variable_scope.VariableSynchronization.ON_READ, @@ -453,7 +456,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): v = variable_scope.variable(1.0, name="foo") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -470,7 +473,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(name): v = variable_scope.variable(1.0, name=name) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -570,7 +573,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): with ops.name_scope("foo"): a = constant_op.constant(1.0, name="a") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) b = constant_op.constant(1.0, name="b") return a, b @@ -591,7 +595,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): with ops.name_scope(None, "foo"): a = constant_op.constant(1.0, name="a") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) b = constant_op.constant(2.0, name="b") return a, b @@ -619,7 +624,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): b = variable_scope.variable(1.0, name="b") with ops.name_scope("foo"): - c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + c = distribution_strategy_context.get_tower_context().merge_call( + in_cross_tower) return b, c dist = mirrored_strategy.MirroredStrategy( @@ -651,7 +657,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): b = variable_scope.get_variable("b", [1]) with ops.name_scope("foo"): - c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + c = distribution_strategy_context.get_tower_context().merge_call( + in_cross_tower) return b, c dist = mirrored_strategy.MirroredStrategy( @@ -833,8 +840,9 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertEquals(1.0, self.evaluate(mirrored_var)) def model_fn(): - value = math_ops.cast(distribute_lib.get_tower_context().tower_id, - mirrored_var.dtype) + value = math_ops.cast( + distribution_strategy_context.get_tower_context().tower_id, + mirrored_var.dtype) return mirrored_var.assign(value) self.evaluate(dist.unwrap(dist.call_for_each_tower( @@ -898,8 +906,9 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertEquals(1.0, self.evaluate(mirrored_var)) def model_fn(): - value = math_ops.cast(distribute_lib.get_tower_context().tower_id, - mirrored_var.dtype) + value = math_ops.cast( + distribution_strategy_context.get_tower_context().tower_id, + mirrored_var.dtype) return mirrored_var.assign_add(value) self.evaluate(dist.unwrap(dist.call_for_each_tower( @@ -963,8 +972,9 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertEquals(5.0, self.evaluate(mirrored_var)) def model_fn(): - value = math_ops.cast(distribute_lib.get_tower_context().tower_id, - mirrored_var.dtype) + value = math_ops.cast( + distribution_strategy_context.get_tower_context().tower_id, + 
mirrored_var.dtype) return mirrored_var.assign_sub(value) self.evaluate(dist.unwrap(dist.call_for_each_tower( diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index a066adf1246..5db2fff2390 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -24,7 +24,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): @@ -68,7 +68,8 @@ class VariableCreatorStackTest(test.TestCase): v = variable_scope.variable(1.0) # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) return v def main_thread_creator(next_creator, *args, **kwargs): diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index cf29c0ed91a..02eb68227df 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -37,7 +37,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, @@ -101,7 +101,8 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, last_part_device = 'device:CPU:0' else: last_part_device = ( - 'device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + 'device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) a = constant_op.constant(1.0) b = constant_op.constant(2.0) @@ -192,14 +193,16 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, tower_compute_device = '/device:CPU:0' else: tower_compute_device = ( - '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + '/device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) tower_compute_device = device_util.canonicalize(tower_compute_device) if 'CPU' in variable_device: tower_variable_device = '/device:CPU:0' else: tower_variable_device = ( - '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + '/device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) tower_variable_device = device_util.canonicalize(tower_variable_device) a = constant_op.constant(1.0) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index baed0ebaae8..371b97ba96a 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -28,7 +28,7 @@ from tensorflow.python.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope from 
tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer @@ -45,7 +45,8 @@ def _raise_exception_fn(_=None): # Must be the argument to a distribution.call_for_each_tower() call, calls a # get_tower_context().merge_call() that raises an exception. def _merge_raises_fn(): - distribute_lib.get_tower_context().merge_call(_raise_exception_fn) + distribution_strategy_context.get_tower_context().merge_call( + _raise_exception_fn) # Must be the argument to a get_tower_context().merge_call() call, calls @@ -58,7 +59,7 @@ def _call_raises_fn(dist): # calls a get_tower_context().merge_call() that calls a # call_for_each_tower() that raises an exception. def _merge_call_raises_fn(): - distribute_lib.get_tower_context().merge_call(_call_raises_fn) + distribution_strategy_context.get_tower_context().merge_call(_call_raises_fn) # Must be the argument to a get_tower_context().merge_call() call, calls @@ -72,7 +73,8 @@ def _call_merge_raises_fn(dist): # get_tower_context().merge_call() that calls a call_for_each_tower() that # calls a get_tower_context().merge_call() that raises an exception. def _merge_call_merge_raises_fn(): - distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn) + distribution_strategy_context.get_tower_context().merge_call( + _call_merge_raises_fn) class DistributionTestBase(test.TestCase): @@ -208,7 +210,7 @@ class DistributionTestBase(test.TestCase): expected_devices = [False] * len(d.worker_devices) def mark_devices_fn(): - tower_id = distribute_lib.get_tower_context().tower_id + tower_id = distribution_strategy_context.get_tower_context().tower_id self.assertLess(tower_id, len(d.worker_devices)) self.assertFalse(expected_devices[tower_id]) expected_devices[tower_id] = True diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 5fd4c9de696..8548a864210 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import saver from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest @@ -56,7 +57,7 @@ class DistributedValues(object): def get(self, device=None): """Returns the value for the current device or raises a ValueError.""" if device is None: - tower_context = distribute_lib.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() if tower_context: device = tower_context.device else: @@ -289,14 +290,15 @@ class DistributedVariable(DistributedDelegate): # We want cross-tower code that does some var.op.X calls # to work (even if the current device isn't in self.devices), but # other uses of var.op in a cross-tower context to fail. 
- if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return DistributedVarOp(self._primary_var.op.name, self._primary_var.op.graph, self._primary_var.op.type) return self.get().op def read_value(self): - return distribute_lib.get_distribution_strategy().read_var(self) + return distribution_strategy_context.get_distribution_strategy().read_var( + self) def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" @@ -362,7 +364,7 @@ class MirroredVariable(DistributedVariable, Mirrored, # update several non-slot variables in one call. def _assign_func(self, *args, **kwargs): f = kwargs.pop("f") - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): update_device = distribute_lib.get_update_device() # We are calling update on the mirrored variable in cross tower context. if update_device is not None: @@ -371,7 +373,7 @@ class MirroredVariable(DistributedVariable, Mirrored, v = self.get(device=update_device) return f(v, *args, **kwargs) - return distribute_lib.get_distribution_strategy().update( + return distribution_strategy_context.get_distribution_strategy().update( self, f, *args, **kwargs) else: _assert_tower_context() @@ -392,8 +394,8 @@ class MirroredVariable(DistributedVariable, Mirrored, aggregation=self._aggregation, value=value, destinations=self), *other_args, **other_kwargs) - return distribute_lib.get_tower_context().merge_call(merge_fn, *args, - **kwargs) + return distribution_strategy_context.get_tower_context().merge_call( + merge_fn, *args, **kwargs) def assign_sub(self, *args, **kwargs): assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) @@ -419,7 +421,7 @@ class MirroredVariable(DistributedVariable, Mirrored, def _as_graph_element(self): # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return self._primary_var._as_graph_element() return self.get()._as_graph_element() @@ -459,7 +461,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): # We use a callable so that we don't have to evaluate this expression # in the case where we are trying to restore instead of save. def tensor(): - return distribute_lib.get_distribution_strategy().read_var( + return distribution_strategy_context.get_distribution_strategy().read_var( tower_local_variable) spec = saver.BaseSaverBuilder.SaveSpec( tensor=tensor, @@ -475,7 +477,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): def _assert_tower_context(): - if not distribute_lib.get_tower_context(): + if not distribution_strategy_context.get_tower_context(): raise RuntimeError( "Tower-local variables may only be assigned in a tower context.") @@ -498,7 +500,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return self.get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): # To preserve the sum across save and restore, we have to divide the # total across all devices when restoring a variable that was summed # when saving. 
@@ -526,7 +528,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice, def _as_graph_element(self): # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return self._get_cross_tower() return self.get()._as_graph_element() @@ -994,12 +996,12 @@ class MultiStepContext(object): outputs as already reduced or not. """ - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): self._last_step_outputs_aggregations[name] = aggregation if aggregation is variables_lib.VariableAggregation.NONE: self._last_step_outputs[name] = output else: - distribution = distribute_lib.get_distribution_strategy() + distribution = distribution_strategy_context.get_distribution_strategy() self._last_step_outputs[name] = distribution.reduce( aggregation, output, destinations="/device:CPU:0") else: @@ -1011,7 +1013,9 @@ class MultiStepContext(object): # context object, so it's more robust to set it only once (even if all # the towers are trying to set the same value). self._last_step_outputs_aggregations[name] = aggregation - distribute_lib.get_tower_context().merge_call(merge_fn, output) + + distribution_strategy_context.get_tower_context().merge_call( + merge_fn, output) @property def non_tensor_outputs(self): @@ -1020,14 +1024,15 @@ class MultiStepContext(object): def set_non_tensor_output(self, name, output): """Set `output` with `name` to be captured as a non tensor output.""" - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): self._non_tensor_outputs[name] = output else: def merge_fn(distribution, value): # NOTE(priyag): For non tensor outputs, we simply return all the values # in a list as aggregation doesn't make sense on non tensors. 
self._non_tensor_outputs[name] = distribution.unwrap(value) - distribute_lib.get_tower_context().merge_call(merge_fn, output) + distribution_strategy_context.get_tower_context().merge_call( + merge_fn, output) def value_container(val): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py index 90910f3839b..200310bc414 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py @@ -173,6 +173,13 @@ class DeterministicTest(test.TestCase): self.assertAllClose( np.zeros(sample_shape_ + (2,)).astype(np.float32), sample_) + def testEntropy(self): + loc = np.array([-0.1, -3.2, 7.]) + deterministic = deterministic_lib.Deterministic(loc=loc) + with self.test_session() as sess: + entropy_ = sess.run(deterministic.entropy()) + self.assertAllEqual(np.zeros(3), entropy_) + class VectorDeterministicTest(test.TestCase): @@ -290,6 +297,13 @@ class VectorDeterministicTest(test.TestCase): self.assertAllClose( np.zeros(sample_shape_ + (2, 1)).astype(np.float32), sample_) + def testEntropy(self): + loc = np.array([[8.3, 1.2, 3.3], [-0.1, -3.2, 7.]]) + deterministic = deterministic_lib.VectorDeterministic(loc=loc) + with self.test_session() as sess: + entropy_ = sess.run(deterministic.entropy()) + self.assertAllEqual(np.zeros(2), entropy_) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index ad853ee293f..affc64a14f6 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -152,6 +152,9 @@ class _BaseDeterministic(distribution.Distribution): """Relative tolerance for comparing points to `self.loc`.""" return self._rtol + def _entropy(self): + return array_ops.zeros(self.batch_shape_tensor(), dtype=self.dtype) + def _mean(self): return array_ops.identity(self.loc) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb index 975105a179f..5621d6a358e 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb @@ -495,7 +495,7 @@ " random_vector_for_generation)\n", " \n", " # saving (checkpoint) the model every 15 epochs\n", - " if epoch % 15 == 0:\n", + " if (epoch + 1) % 15 == 0:\n", " checkpoint.save(file_prefix = checkpoint_prefix)\n", " \n", " print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index 78a711548dd..027097908f2 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -132,6 +132,7 @@ "tf.enable_eager_execution()\n", "\n", "import numpy as np\n", + "import os\n", "import re\n", "import random\n", "import unidecode\n", @@ -313,7 +314,7 @@ "outputs": [], "source": [ "dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n", - "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))" 
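The checkpoint-cadence change in this notebook (repeated in the notebooks below) is a pure off-by-one fix: with `epoch` counting from 0, the old condition saves right after the first epoch and then on a shifted schedule, while the new one saves after every 15 completed epochs. A quick plain-Python check of both schedules:

```python
EPOCHS = 30
SAVE_EVERY = 15

old_schedule = [epoch + 1 for epoch in range(EPOCHS) if epoch % SAVE_EVERY == 0]
new_schedule = [epoch + 1 for epoch in range(EPOCHS) if (epoch + 1) % SAVE_EVERY == 0]

print(old_schedule)  # [1, 16]  -- saves after the 1st and 16th epochs
print(new_schedule)  # [15, 30] -- saves after every 15 completed epochs
```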
+ "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" ] }, { @@ -493,7 +494,7 @@ "source": [ "# Training step\n", "\n", - "EPOCHS = 30\n", + "EPOCHS = 20\n", "\n", "for epoch in range(EPOCHS):\n", " start = time.time()\n", @@ -520,7 +521,7 @@ " batch,\n", " loss))\n", " # saving (checkpoint) the model every 5 epochs\n", - " if epoch % 5 == 0:\n", + " if (epoch + 1) % 5 == 0:\n", " checkpoint.save(file_prefix = checkpoint_prefix)\n", "\n", " print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n", diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 1d07721e3b6..08d8364978f 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -319,7 +319,7 @@ "vocab_tar_size = len(targ_lang.word2idx)\n", "\n", "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n", - "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))" + "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" ] }, { @@ -619,7 +619,7 @@ " batch,\n", " batch_loss.numpy()))\n", " # saving (checkpoint) the model every 2 epochs\n", - " if epoch % 2 == 0:\n", + " if (epoch + 1) % 2 == 0:\n", " checkpoint.save(file_prefix = checkpoint_prefix)\n", " \n", " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", diff --git a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb index acc0f5b6531..ee25d25b52a 100644 --- a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb +++ b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb @@ -701,7 +701,7 @@ " generate_images(generator, inp, tar)\n", " \n", " # saving (checkpoint) the model every 20 epochs\n", - " if epoch % 20 == 0:\n", + " if (epoch + 1) % 20 == 0:\n", " checkpoint.save(file_prefix = checkpoint_prefix)\n", "\n", " print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n", diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 82272bf1207..77f62df99d5 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -20,6 +20,7 @@ py_library( ":dnn_linear_combined", ":early_stopping", ":export", + ":exporter", ":extenders", ":head", ":hooks", @@ -219,6 +220,33 @@ py_test( ], ) +py_library( + name = "exporter", + srcs = [ + "python/estimator/exporter.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python/estimator:exporter", + ], +) + +py_test( + name = "exporter_test", + size = "medium", + srcs = ["python/estimator/exporter_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":exporter", + "//tensorflow/python:platform", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:exporter", + ], +) + py_library( name = "head", srcs = [ diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index e1453ae1d04..6ad3a4a6049 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -45,6 +45,7 @@ _allowed_symbols = [ 'clip_gradients_by_norm', 'forward_features', 'InMemoryEvaluatorHook', + 'StopAtCheckpointStepHook', 
'logistic_regression_head', 'multi_class_head', 'multi_head', diff --git a/tensorflow/contrib/estimator/python/estimator/exporter.py b/tensorflow/contrib/estimator/python/estimator/exporter.py new file mode 100644 index 00000000000..09d74406056 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/exporter.py @@ -0,0 +1,280 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements StepsExporter to export the model in user specified steps.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.estimator import exporter +from tensorflow.python.framework import ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging +from tensorflow.python.summary import summary_iterator + +DEFAULT_GLOBAL_STEP_KEY = ops.GraphKeys.GLOBAL_STEP + + +class StepsExporter(exporter.Exporter): + """This class exports the model in user specified steps. + + This class exports the model at the steps given by the `steps_to_keep` + argument. Each number in the list is treated as a lower bound for model + exports, to handle the case when evaluation is performed at different steps. + + Consider this example: + + ``` + steps_to_keep = [1, 2, 3, 6, 7, 10, 12, 25] + ``` + + The model is evaluated at step increments of 5: `[5, 10, 15, 20, 25, 30]`. + The `StepsExporter` will export the model when it has reached steps + `[5, 10, 15, 25]`. + + This example illustrates the two cases when the model is exported: + + 1. Model is evaluated on a step defined in the list `steps_to_keep`. + + In the example, the model is exported on step `10` and `25`. + + 2. Model is evaluated on a step not defined in the list `steps_to_keep`, but + is still exported because a step in `steps_to_keep` was missed. + + In the example, when the model reaches step `5`, the model is exported even + though `steps_to_keep` does not contain `5`. Step `5` is exported to make + up for step `3`, which was missed. Steps `1` and `2` in `steps_to_keep` are + skipped completely (e.g. say the model is evaluated at step `6`. It will + **not** be exported to make up for step `2`). + + Using the `steps_to_keep` list as a lower bound allows users to define + approximate step boundaries for exporting their models, and avoid frustrating + off-by-one calculation errors. + + Sample Use Cases: + There are specific points during the training when having a saved version of + the model would be useful. One example is at the end of each training phase + when the set of freezed weights is changed. + Another good use case is saving the model at the end of each epoch for + visualization or retraining. 
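The "lower bound" rule described in the docstring above is easiest to see as a short simulation. The sketch below is an editorial illustration only (it is not part of this patch); `simulate_exports` is a hypothetical helper that replays the documented example and reproduces the stated result.

```python
# Illustrative only: a pure-Python replay of the lower-bound export rule
# described in the StepsExporter docstring above.
def simulate_exports(steps_to_keep, evaluated_steps):
  """Returns the evaluated global steps at which an export would happen."""
  pending = sorted(steps_to_keep)
  exported_at = []
  for global_step in evaluated_steps:
    # Every pending step <= the current evaluated step counts as "missed",
    # so the current evaluation triggers an export on their behalf.
    missed = [s for s in pending if s <= global_step]
    if missed:
      exported_at.append(global_step)
      pending = [s for s in pending if s > global_step]
  return exported_at

print(simulate_exports([1, 2, 3, 6, 7, 10, 12, 25],
                       [5, 10, 15, 20, 25, 30]))  # -> [5, 10, 15, 25]
```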
+ """ + + def __init__(self, + steps_to_keep, + name='steps_exporter', + serving_input_receiver_fn=None, + event_file_pattern='eval/*.tfevents.*', + assets_extra=None, + as_text=False): + """Create an `StepsExporter` to use with `tf.estimator.EvalSpec`. + + Example of creating a StepsExporter for training and evaluation: + + ```python + categorical_feature_a = categorical_column_with_hash_bucket(...) + categorical_feature_b = categorical_column_with_hash_bucket(...) + + categorical_feature_a_emb = embedding_column( + categorical_column=categorical_feature_a, ...) + categorical_feature_b_emb = embedding_column( + categorical_column=categorical_feature_b, ...) + + estimator = tf.estimator.DNNClassifier( + feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], + hidden_units=[1024, 512, 256]) + + # Input pipeline for train and evaluate. + def train_input_fn: # returns x, y + # please shuffle the data. + pass + def eval_input_fn_eval: # returns x, y + pass + + exporter = tf.contrib.estimator.exporter.StepsExporter( + name="steps_exporter", + serving_input_receiver_fn=serving_input_receiver_fn, + event_file_pattern='eval/*.tfevents.*' + steps_to_keep=[...]) + + train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000) + + eval_spec = [tf.estimator.EvalSpec( + input_fn=eval_input_fn, + steps=1, + exporters=exporter, + start_delay_secs=0, + throttle_secs=5)] + + tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) + + # Models will be exported to estimator.model_dir in timestamped directories, + # which can be used for serving, analysis with TFMA, or directly loaded in. + # For example: + export_dir = os.path.join(estimator.model_dir, + ) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + tf.saved_model.loader.load( + sess, [tf.saved_model.tag_constants.SERVING], export_dir) + + ``` + + Args: + steps_to_keep: Non-empty list of positive integers containing + the step numbers at which the model should be exported. All the exports + will be kept, so there is no garbage collection. + name: Unique name of this `Exporter` that is going to be used in the + export path. + serving_input_receiver_fn: A function that takes no arguments and returns + a `ServingInputReceiver`. + event_file_pattern: Event file name pattern relative to model_dir. If + None, however, the exporter would not be preemption-safe. To be + preemption-safe, event_file_pattern should be specified. + assets_extra: An optional dict specifying how to populate the assets.extra + directory within the exported SavedModel. Each key should give the + destination path (including the filename) relative to the assets.extra + directory. The corresponding value gives the full path of the source + file to be copied. For example, the simple case of copying a single + file without renaming it is specified as `{'my_asset_file.txt': + '/path/to/my_asset_file.txt'}`. + as_text: Whether to write the SavedModel proto in text format. Defaults to + `False`. + + Raises: + ValueError: If any arguments is invalid. 
+ """ + # pylint: disable=protected-access + self._saved_model_exporter = exporter._SavedModelExporter( + name, serving_input_receiver_fn, assets_extra, as_text) + # pylint: enable=protected-access + + self._event_file_pattern = event_file_pattern + self._model_dir = None + + self._input_steps_to_keep = steps_to_keep + steps_to_keep = [step for step in steps_to_keep if isinstance(step, int)] + steps_to_keep = [step for step in steps_to_keep if step > 0] + if not steps_to_keep: + raise ValueError( + '`steps_to_keep` list must have at least one positive integer') + elif self._input_steps_to_keep != steps_to_keep: + tf_logging.warn('Changed `steps_to_keep`, by omitting non-integer or' + ' less than 1 elements, to [%s]', + ', '.join(str(step) for step in steps_to_keep)) + self._steps_to_keep = sorted(steps_to_keep) + self._steps_kept = [] + + @property + def name(self): + return self._saved_model_exporter.name + + def export(self, estimator, export_path, checkpoint_path, eval_result, + is_the_final_export): + """Exports the given Estimator to a specific format. + + Args: + estimator: A `tf.estimator.Estimator` instance to export. + export_path: A string containing a directory where to write the export. + checkpoint_path: The checkpoint path to export. + eval_result: The output of Estimator.evaluate on this checkpoint. + is_the_final_export: This boolean is True when this is an export in the + end of training. It is False for the intermediate exports during the + training. When passing Exporter to tf.estimator.train_and_evaluate + is_the_final_export is always False if TrainSpec.max_steps is None. + + Returns: + The string path to the exported directory or None if export is skipped. + + Raises: + ValueError: If `eval_result` is None or doesn't have + `ops.GraphKeys.GLOBAL_STEP` as a key. + """ + export_result = None + + if not eval_result or DEFAULT_GLOBAL_STEP_KEY not in eval_result: + raise ValueError( + '`eval_result` is empty, or does not have global step. This' + ' should never happen as Estimator always sets the global step in ' + '`eval_result`. Please file a bug report. Got eval_result: %s' + % str(eval_result)) + + if self._model_dir != estimator.model_dir and self._event_file_pattern: + tf_logging.info('Loads the steps that the model was already evaluated at,' + 'from event files') + self._model_dir = estimator.model_dir + full_event_file_pattern = os.path.join(self._model_dir, + self._event_file_pattern) + self._steps_kept = self._get_kept_steps(full_event_file_pattern) + + if self._steps_kept: + self._steps_kept = sorted(self._steps_kept) + self._steps_to_keep = [step for step in self._steps_to_keep if + step > self._steps_kept[-1]] + # It is assumed that the model is exported at any evaluated step 'n' if + # there is any `steps_missed` lower than 'n'. As a result, all the steps in + # `_steps_to_keep` lower than the last evaluated step will be removed. + steps_missed = [step for step in self._steps_to_keep + if step <= eval_result[DEFAULT_GLOBAL_STEP_KEY]] + + if steps_missed: + # update the `_steps_to_keep` list by omitting all steps smaller than the + # current global step which are missed to be exported + export_result = self._saved_model_exporter.export(estimator, export_path, + checkpoint_path, + eval_result, + is_the_final_export) + self._steps_to_keep = [step for step in self._steps_to_keep if step + not in steps_missed] + # contains all the steps in which export has happened. 
+ self._steps_kept.append(eval_result[DEFAULT_GLOBAL_STEP_KEY]) + # Show warning for all the missed steps except the last one + if steps_missed[:-1]: + tf_logging.warn('Missed steps [%s] for exporting, as no evaluation' + ' took place at them.', ', '.join(str(step) for step in + steps_missed[:-1])) + # Log model export if the last missed step is the same as the current step + if steps_missed[-1] == eval_result[DEFAULT_GLOBAL_STEP_KEY]: + tf_logging.info('Performing model export at step %d.', + eval_result[DEFAULT_GLOBAL_STEP_KEY]) + # Show warning for exporting model at another step instead of the user + # specified one + else: + tf_logging.warn('Performing model export at step %d instead of %d, as' + ' no evaluation took place at step %d.', + eval_result[DEFAULT_GLOBAL_STEP_KEY], steps_missed[-1], + steps_missed[-1]) + return export_result + + def _get_kept_steps(self, event_files): + """Get the steps that the model was evaluated at, from event files. + + Args: + event_files: Absolute pattern of event files. + + Returns: + steps_kept: A list of steps in which the model was evaluated. + """ + if not event_files: + return None + + steps_kept = [] + for event_file in gfile.Glob(os.path.join(event_files)): + for event in summary_iterator.summary_iterator(event_file): + if event.step not in steps_kept: + steps_kept.append(event.step) + return steps_kept diff --git a/tensorflow/contrib/estimator/python/estimator/exporter_test.py b/tensorflow/contrib/estimator/python/estimator/exporter_test.py new file mode 100644 index 00000000000..0d009b945e7 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/exporter_test.py @@ -0,0 +1,206 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for `StepsExporter`.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +from tensorflow.contrib.estimator.python.estimator import exporter as exporter_lib +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test + + +class StepsExporterTest(test.TestCase): + + def test_error_out_if_steps_to_keep_has_no_positive_integers(self): + + def _serving_input_receiver_fn(): + pass + + with self.assertRaisesRegexp(ValueError, "positive integer"): + exporter = exporter_lib.StepsExporter( + name="specified_steps_exporter", + serving_input_receiver_fn=_serving_input_receiver_fn, + steps_to_keep=[-1, 0, 1.1]) + self.assertEqual("specified_steps_exporter", exporter.name) + + def test_steps_exporter(self): + + def _serving_input_receiver_fn(): + pass + + export_dir_base = tempfile.mkdtemp() + gfile.MkDir(export_dir_base) + gfile.MkDir(export_dir_base + "/export") + gfile.MkDir(export_dir_base + "/eval") + + exporter = exporter_lib.StepsExporter( + name="steps_exporter", + serving_input_receiver_fn=_serving_input_receiver_fn, + assets_extra={"from/path": "to/path"}, + as_text=False, + steps_to_keep=[1]) + estimator = test.mock.Mock(spec=estimator_lib.Estimator) + estimator.export_savedmodel.return_value = "export_result_path" + estimator.model_dir = export_dir_base + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 1}, + False) + + self.assertEqual("export_result_path", export_result) + estimator.export_savedmodel.assert_called_with( + export_dir_base, + _serving_input_receiver_fn, + assets_extra={"from/path": "to/path"}, + as_text=False, + checkpoint_path="checkpoint_path", + strip_default_attrs=True) + + shutil.rmtree(export_dir_base, ignore_errors=True) + + def test_steps_exporter_with_preemption(self): + + def _serving_input_receiver_fn(): + pass + + export_dir_base = tempfile.mkdtemp() + gfile.MkDir(export_dir_base) + gfile.MkDir(export_dir_base + "/export") + gfile.MkDir(export_dir_base + "/eval") + + eval_dir_base = os.path.join(export_dir_base, "eval_continuous") + estimator_lib._write_dict_to_summary(eval_dir_base, {}, 1) + estimator_lib._write_dict_to_summary(eval_dir_base, {}, 2) + + exporter = exporter_lib.StepsExporter( + name="steps_exporter", + serving_input_receiver_fn=_serving_input_receiver_fn, + event_file_pattern="eval_continuous/*.tfevents.*", + assets_extra={"from/path": "to/path"}, + as_text=False, + steps_to_keep=[1, 2, 6, 8]) + + estimator = test.mock.Mock(spec=estimator_lib.Estimator) + estimator.model_dir = export_dir_base + estimator.export_savedmodel.return_value = "export_result_path" + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 3}, + False) + self.assertEqual(None, export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 6}, + False) + self.assertEqual("export_result_path", export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 7}, + False) + self.assertEqual(None, export_result) + + shutil.rmtree(export_dir_base, ignore_errors=True) + + def test_specified_step_is_saved(self): + + def _serving_input_receiver_fn(): + pass + + export_dir_base = 
tempfile.mkdtemp() + gfile.MkDir(export_dir_base) + gfile.MkDir(export_dir_base + "/export") + gfile.MkDir(export_dir_base + "/eval") + + exporter = exporter_lib.StepsExporter( + name="steps_exporter", + serving_input_receiver_fn=_serving_input_receiver_fn, + assets_extra={"from/path": "to/path"}, + as_text=False, + steps_to_keep=[1, 5, 8, 10, 11]) + estimator = test.mock.Mock(spec=estimator_lib.Estimator) + estimator.export_savedmodel.return_value = "export_result_path" + estimator.model_dir = export_dir_base + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 1}, + False) + + self.assertTrue(estimator.export_savedmodel.called) + self.assertEqual("export_result_path", export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 2}, + False) + self.assertEqual(None, export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 5}, + False) + self.assertTrue(estimator.export_savedmodel.called) + self.assertEqual("export_result_path", export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 10}, + False) + self.assertTrue(estimator.export_savedmodel.called) + self.assertEqual("export_result_path", export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 15}, + False) + self.assertTrue(estimator.export_savedmodel.called) + self.assertEqual("export_result_path", export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 20}, + False) + self.assertEqual(None, export_result) + + shutil.rmtree(export_dir_base, ignore_errors=True) + + def test_steps_exporter_with_no_global_step_key(self): + + def _serving_input_receiver_fn(): + pass + + export_dir_base = tempfile.mkdtemp() + gfile.MkDir(export_dir_base) + gfile.MkDir(export_dir_base + "/export") + gfile.MkDir(export_dir_base + "/eval") + + exporter = exporter_lib.StepsExporter( + name="steps_exporter", + serving_input_receiver_fn=_serving_input_receiver_fn, + assets_extra={"from/path": "to/path"}, + as_text=False, + steps_to_keep=[1]) + estimator = test.mock.Mock(spec=estimator_lib.Estimator) + estimator.export_savedmodel.return_value = "export_result_path" + estimator.model_dir = export_dir_base + + with self.assertRaisesRegexp(ValueError, "does not have global step"): + exporter.export(estimator, export_dir_base, "checkpoint_path", {}, False) + + shutil.rmtree(export_dir_base, ignore_errors=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py index caadafdfa69..faefda7c489 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +import time from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.framework import ops @@ -26,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import training +from tensorflow.python.training import training_util # pylint: disable=protected-access @@ -210,4 +212,55 @@ class InMemoryEvaluatorHook(training.SessionRunHook): 
self._evaluate(session) +class StopAtCheckpointStepHook(training.SessionRunHook): + """Hook that requests stop at a specified step based on checkpoint.""" + + def __init__(self, model_dir, last_step, + wait_after_file_check_secs=30): + """Initializes a `StopAtCheckpointStepHook`. + + This hook requests stop after a last step has been reached. It checks latest + checkpoint to verify last step is written on disk or not. + + Args: + model_dir: Directory to read global step from latest checkpoint. + last_step: Step after which to stop. + wait_after_file_check_secs: Reading same file by many workers may create + I/O issues. To throttle that we will wait given secs after each read of + the file. + + Raises: + ValueError: If one of the arguments is invalid. + """ + if last_step is None: + raise ValueError('last_step must be specified.') + if model_dir is None: + raise ValueError('model_dir must be specified.') + + self._model_dir = model_dir + self._last_step = last_step + self._wait_after_file_check_secs = wait_after_file_check_secs + + def begin(self): + self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access + if self._global_step_tensor is None: + raise RuntimeError( + 'Global step should be created to use StopAtCheckpointStepHook.') + + def before_run(self, run_context): # pylint: disable=unused-argument + return training.SessionRunArgs(self._global_step_tensor) + + def after_run(self, run_context, run_values): + global_step = run_values.results + 1 + if global_step >= self._last_step: + # Check latest global step in the checkpoint to ensure that the targeted + # last step is written on disk. + + step = estimator_lib._load_global_step_from_checkpoint_dir( + self._model_dir) + if step >= self._last_step: + run_context.request_stop() + else: + time.sleep(self._wait_after_file_check_secs) + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py index ee88d5ecf50..42352aa3ffb 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py @@ -21,8 +21,11 @@ from __future__ import print_function import glob import json import os +import tempfile +import time from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib +from tensorflow.python.client import session as tf_session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator import run_config as run_config_lib @@ -316,5 +319,59 @@ class InMemoryEvaluatorHookTest(test.TestCase): estimator.train(input_fn, hooks=[evaluator]) +class StopAtCheckpointStepHookTest(test.TestCase): + + def test_do_not_stop_if_checkpoint_is_not_there(self): + with ops.Graph().as_default(): + step = training.create_global_step() + assign_ten = step.assign(10) + no_op = control_flow_ops.no_op() + hook = hooks_lib.StopAtCheckpointStepHook( + model_dir=tempfile.mkdtemp(), last_step=10) + with training.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.raw_session().run(assign_ten) + with test.mock.patch.object(time, 'sleep') as mock_sleep: + mon_sess.run(no_op) + self.assertTrue(mock_sleep.called) + self.assertFalse(mon_sess.should_stop()) + + def test_do_not_stop_if_checkpoint_step_is_smaller(self): + model_dir = tempfile.mkdtemp() + with ops.Graph().as_default(): + step = training.create_global_step() + assign_nine = 
step.assign(9) + assign_ten = step.assign(10) + no_op = control_flow_ops.no_op() + hook = hooks_lib.StopAtCheckpointStepHook( + model_dir=model_dir, last_step=10) + with tf_session.Session() as sess: + sess.run(assign_nine) + training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) + with training.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.raw_session().run(assign_ten) + with test.mock.patch.object(time, 'sleep') as mock_sleep: + mon_sess.run(no_op) + self.assertTrue(mock_sleep.called) + self.assertFalse(mon_sess.should_stop()) + + def test_stop_if_checkpoint_step_is_laststep(self): + model_dir = tempfile.mkdtemp() + with ops.Graph().as_default(): + step = training.create_global_step() + assign_ten = step.assign(10) + no_op = control_flow_ops.no_op() + hook = hooks_lib.StopAtCheckpointStepHook( + model_dir=model_dir, last_step=10) + with tf_session.Session() as sess: + sess.run(assign_ten) + training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) + with training.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.raw_session().run(assign_ten) + with test.mock.patch.object(time, 'sleep') as mock_sleep: + mon_sess.run(no_op) + self.assertFalse(mock_sleep.called) + self.assertTrue(mon_sess.should_stop()) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 4d8d5004fe2..f384d761a84 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -188,7 +188,6 @@ class _ModelFn(object): # center. # is_initialized: scalar indicating whether the initial cluster centers # have been chosen; see init_op. - # cluster_centers_var: a Variable containing the cluster centers. # init_op: an op to choose the initial cluster centers. A single worker # repeatedly executes init_op until is_initialized becomes True. 
# training_op: an op that runs an iteration of training, either an entire diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 82e3bbe3c01..9866fccfba3 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -424,9 +424,11 @@ py_library( ":namedtuples", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:functional_ops", "//tensorflow/python:math_ops", "//tensorflow/python:summary", "//tensorflow/python:util", + "//tensorflow/python:variable_scope", "//tensorflow/python/ops/losses", ], ) diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 508f487722f..f9995bb19d0 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -22,7 +22,9 @@ from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.eval.python import eval_utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import util as loss_util from tensorflow.python.summary import summary @@ -32,6 +34,7 @@ __all__ = [ 'add_gan_model_summaries', 'add_regularization_loss_summaries', 'add_cyclegan_image_summaries', + 'add_stargan_image_summaries' ] @@ -179,6 +182,94 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2, max_outputs=1) +def add_stargan_image_summaries(stargan_model, + num_images=2, + display_diffs=False): + """Adds image summaries to see StarGAN image results. + + If display_diffs is True, each image result has `2` rows and `num_domains + 1` + columns. + The first row looks like: + [original_image, transformed_to_domain_0, transformed_to_domain_1, ...] + The second row looks like: + [no_modification_baseline, transformed_to_domain_0-original_image, ...] + If display_diffs is False, only the first row is shown. + + IMPORTANT: + Since the model originally does not transformed the image to every domains, + we will transform them on-the-fly within this function in parallel. + + Args: + stargan_model: A StarGANModel tuple. + num_images: The number of examples/images to be transformed and shown. + display_diffs: Also display the difference between generated and target. + + Raises: + ValueError: If input_data is not images. + ValueError: If input_data_domain_label is not rank 2. + ValueError: If dimension 2 of input_data_domain_label is not fully defined. + """ + + _assert_is_image(stargan_model.input_data) + stargan_model.input_data_domain_label.shape.assert_has_rank(2) + stargan_model.input_data_domain_label.shape[1:].assert_is_fully_defined() + + num_domains = stargan_model.input_data_domain_label.get_shape().as_list()[-1] + + def _build_image(image): + """Helper function to create a result for each image on the fly.""" + + # Expand the first dimension as batch_size = 1. + images = array_ops.expand_dims(image, axis=0) + + # Tile the image num_domains times, so we can get all transformed together. + images = array_ops.tile(images, [num_domains, 1, 1, 1]) + + # Create the targets to 0, 1, 2, ..., num_domains-1. + targets = array_ops.one_hot(list(range(num_domains)), num_domains) + + with variable_scope.variable_scope( + stargan_model.generator_scope, reuse=True): + + # Add the original image. 
+ output_images_list = [image] + + # Generate the image and add to the list. + gen_images = stargan_model.generator_fn(images, targets) + gen_images_list = array_ops.split(gen_images, num_domains) + gen_images_list = [ + array_ops.squeeze(img, axis=0) for img in gen_images_list + ] + output_images_list.extend(gen_images_list) + + # Display diffs. + if display_diffs: + diff_images = gen_images - images + diff_images_list = array_ops.split(diff_images, num_domains) + diff_images_list = [ + array_ops.squeeze(img, axis=0) for img in diff_images_list + ] + output_images_list.append(array_ops.zeros_like(image)) + output_images_list.extend(diff_images_list) + + # Create the final image. + final_image = eval_utils.image_reshaper( + output_images_list, num_cols=num_domains + 1) + + # Reduce the first rank. + return array_ops.squeeze(final_image, axis=0) + + summary.image( + 'stargan_image_generation', + functional_ops.map_fn( + _build_image, + stargan_model.input_data[:num_images], + parallel_iterations=num_images, + back_prop=False, + swap_memory=True), + max_outputs=num_images) + + def add_gan_model_summaries(gan_model): """Adds typical GANModel summaries. diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index 33d51bfc218..54a6f8d4d90 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.eval.python import summaries_impl as summaries from tensorflow.python.framework import ops @@ -37,6 +36,10 @@ def discriminator_model(inputs, _): return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs +def stargan_generator_model(inputs, _): + return generator_model(inputs) + + def get_gan_model(): # TODO(joelshor): Find a better way of creating a variable scope. with variable_scope.variable_scope('generator') as gen_scope: @@ -57,6 +60,31 @@ def get_gan_model(): discriminator_fn=discriminator_model) +def get_stargan_model(): + """Similar to get_gan_model().""" + # TODO(joelshor): Find a better way of creating a variable scope. 
+ with variable_scope.variable_scope('discriminator') as dis_scope: + pass + with variable_scope.variable_scope('generator') as gen_scope: + return namedtuples.StarGANModel( + input_data=array_ops.ones([1, 2, 2, 3]), + input_data_domain_label=array_ops.ones([1, 2]), + generated_data=stargan_generator_model( + array_ops.ones([1, 2, 2, 3]), None), + generated_data_domain_target=array_ops.ones([1, 2]), + reconstructed_data=array_ops.ones([1, 2, 2, 3]), + discriminator_input_data_source_predication=array_ops.ones([1]), + discriminator_generated_data_source_predication=array_ops.ones([1]), + discriminator_input_data_domain_predication=array_ops.ones([1, 2]), + discriminator_generated_data_domain_predication=array_ops.ones([1, 2]), + generator_variables=None, + generator_scope=gen_scope, + generator_fn=stargan_generator_model, + discriminator_variables=None, + discriminator_scope=dis_scope, + discriminator_fn=discriminator_model) + + def get_cyclegan_model(): with variable_scope.variable_scope('x2y'): model_x2y = get_gan_model() @@ -143,6 +171,16 @@ class SummariesTest(test.TestCase): with self.test_session(use_gpu=True): summary.merge_all().eval() + def test_add_image_comparison_summaries_for_stargan(self): + + summaries.add_stargan_image_summaries(get_stargan_model()) + + self.assertEquals(1, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + + with self.test_session(use_gpu=True) as sess: + sess.run(variables.global_variables_initializer()) + summary.merge_all().eval() + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 03f52d214b5..9e5aea1498a 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -52,7 +52,6 @@ from tensorflow.python.training import session_run_hook from tensorflow.python.training import sync_replicas_optimizer from tensorflow.python.training import training_util - __all__ = [ 'gan_model', 'infogan_model', @@ -61,6 +60,7 @@ __all__ = [ 'stargan_model', 'gan_loss', 'cyclegan_loss', + 'stargan_loss', 'gan_train_ops', 'gan_train', 'get_sequential_train_hooks', @@ -646,8 +646,9 @@ def gan_loss( type(model)) # Optionally create pooled model. - pooled_model = (_tensor_pool_adjusted_model(model, tensor_pool_fn) if - tensor_pool_fn else model) + pooled_model = ( + _tensor_pool_adjusted_model(model, tensor_pool_fn) + if tensor_pool_fn else model) # Create standard losses. 
gen_loss = generator_loss_fn(model, add_summaries=add_summaries) @@ -665,9 +666,10 @@ def gan_loss( if _use_aux_loss(mutual_information_penalty_weight): gen_info_loss = tfgan_losses.mutual_information_penalty( model, add_summaries=add_summaries) - dis_info_loss = (gen_info_loss if tensor_pool_fn is None else - tfgan_losses.mutual_information_penalty( - pooled_model, add_summaries=add_summaries)) + dis_info_loss = ( + gen_info_loss + if tensor_pool_fn is None else tfgan_losses.mutual_information_penalty( + pooled_model, add_summaries=add_summaries)) gen_loss += mutual_information_penalty_weight * gen_info_loss dis_loss += mutual_information_penalty_weight * dis_info_loss if _use_aux_loss(aux_cond_generator_weight): diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc index b510994152b..80b2d3e08b6 100644 --- a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc +++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc @@ -204,11 +204,11 @@ class SequenceFileDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const std::vector& filenames, const DataTypeVector& output_types) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), filenames_(filenames), output_types_(output_types) {} @@ -233,7 +233,8 @@ class SequenceFileDatasetOp : public DatasetOpKernel { } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* filenames = nullptr; TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py index 61f78febfc0..7b7ac4f347e 100644 --- a/tensorflow/contrib/integrate/python/ops/odes.py +++ b/tensorflow/contrib/integrate/python/ops/odes.py @@ -73,7 +73,7 @@ def _scaled_dot_product(scale, xs, ys, name=None): # _possibly_nonzero lets us avoid wasted computation. return math_ops.add_n( [(scale * x) * y for x, y in zip(xs, ys) - if _possibly_nonzero(x) or _possibly_nonzero(y)], + if _possibly_nonzero(x) and _possibly_nonzero(y)], name=scope) @@ -122,7 +122,7 @@ def _runge_kutta_step(func, yi = y0 + _scaled_dot_product(dt_cast, beta_i, k) k.append(func(yi, ti)) - if not (tableau.c_sol[-1] == 0 and tableau.c_sol == tableau.beta[-1]): + if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]): # This property (true for Dormand-Prince) lets us save a few FLOPs. 
yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k) diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index 92ae79d3c7a..d0ea961473c 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -52,12 +52,12 @@ class KafkaDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, std::vector topics, const string& servers, const string& group, const bool eof, const int64 timeout) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), topics_(std::move(topics)), servers_(servers), group_(group), diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py index d6b1a61b716..44e01e1aebf 100644 --- a/tensorflow/contrib/kfac/examples/convnet.py +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -202,7 +202,7 @@ def minimize_loss_single_machine(loss, accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. - device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and invserse + device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse update ops are run on this device. session_config: None or tf.ConfigProto. Configuration for tf.Session(). @@ -470,7 +470,7 @@ def train_mnist_single_machine(data_dir, data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. use_fake_data: bool. If True, generate a synthetic dataset. - device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and inverse + device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse update ops are run on this device. Returns: @@ -509,7 +509,7 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, num_epochs: int. Number of passes to make over the training set. num_towers: int. Number of CPUs to split inference across. use_fake_data: bool. If True, generate a synthetic dataset. - devices: string, Either list of CPU or GPU. The covaraince and inverse + devices: string, Either list of CPU or GPU. The covariance and inverse update ops are run on this device. Returns: @@ -621,7 +621,7 @@ def train_mnist_distributed_sync_replicas(task_id, data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. op_strategy: `string`, Strategy to run the covariance and inverse - ops. If op_strategy == `chief_worker` then covaraiance and inverse + ops. If op_strategy == `chief_worker` then covariance and inverse update ops are run on chief worker otherwise they are run on dedicated workers. diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index 854f885c26f..323234c4031 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -97,8 +97,8 @@ class FisherEstimator(object): and to regularize the update direction by making it closer to the gradient. (Higher damping means the update looks more like a standard gradient update - see Tikhonov regularization.) 
- layer_collection: The layer collection object, which holds the fisher - blocks, kronecker factors, and losses associated with the + layer_collection: The layer collection object, which holds the Fisher + blocks, Kronecker factors, and losses associated with the graph. exps: List of floats or ints. These represent the different matrix powers of the approximate Fisher that the FisherEstimator will be able @@ -464,7 +464,7 @@ class FisherEstimator(object): def _get_grads_lists_empirical(self, tensors): # Passing in a list of loss values is better than passing in the sum as - # the latter creates unnessesary ops on the default device + # the latter creates unnecessary ops on the default device grads_flat = gradients_impl.gradients( self._layers.eval_losses(), nest.flatten(tensors), diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index 3a5c8eb5f96..9fa6eb7dcd1 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -870,7 +870,7 @@ class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): Estimates the Fisher Information matrix's blog for a convolutional layer. - Consider a convoluational layer in this model with (unshared) filter matrix + Consider a convolutional layer in this model with (unshared) filter matrix 'w'. For a minibatch that produces inputs 'a' and output preactivations 's', this FisherBlock estimates, diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index b43232dfafa..afa2fd1ca72 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -71,15 +71,15 @@ _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1 # factor. This parameter is used only if `_SUB_SAMPLE_INPUTS` is True. _INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5 -# If True, then subsamples the tensor passed to compute the covaraince matrix. +# If True, then subsamples the tensor passed to compute the covariance matrix. _SUB_SAMPLE_OUTER_PRODUCTS = False -# If True, then subsamples the tensor passed to compute the covaraince matrix. +# If True, then subsamples the tensor passed to compute the covariance matrix. _SUB_SAMPLE_INPUTS = False # TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data # passed to the factors from the blocks will be concatenated across towers -# (lazilly via PartitionedTensor objects). Otherwise a tuple of tensors over +# (lazily via PartitionedTensor objects). Otherwise a tuple of tensors over # towers will be passed in, and the factors will iterate over this and do the # cov computations separately for each one, averaging the results together. TOWER_STRATEGY = "concat" @@ -309,7 +309,7 @@ def _subsample_for_cov_computation(array, name=None): def _random_tensor_gather(array, max_size): - """Generates a random set of indices and gathers the value at the indcices. + """Generates a random set of indices and gathers the value at the indices. Args: array: Tensor, of shape `[batch_size, dim_2]`. @@ -1762,8 +1762,8 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): # Might need to enforce symmetry lost due to numerical issues. invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 - # The following line imposses the symmetry assumed by "Option 1" on C1. - # Stangely the code can work okay with this line commented out, + # The following line imposes the symmetry assumed by "Option 1" on C1. 
+ # Strangely the code can work okay with this line commented out, # depending on how psd_eig is defined. I'm not sure why. C1 = (C1 + array_ops.transpose(C1)) / 2.0 diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index cbbfe7212c9..43aa713edcb 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -609,7 +609,7 @@ class LayerCollection(object): outputs, approx=None, reuse=VARIABLE_SCOPE): - """Registers a fully connnected layer. + """Registers a fully connected layer. Args: params: Tensor or 2-tuple of Tensors corresponding to weight and bias of @@ -975,7 +975,7 @@ class LayerCollection(object): block for this layer (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the word `use` here has a completely different meaning to "use in the graph" - as it perturns to the `inputs`, `outputs`, and `num_uses` arguments.) + as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.) (Default: "VARIABLE_SCOPE") Raises: @@ -1045,7 +1045,7 @@ class LayerCollection(object): block for this layer (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the word `use` here has a completely different meaning to "use in the graph" - as it perturns to the `inputs`, `outputs`, and `num_uses` arguments.) + as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.) (Default: "VARIABLE_SCOPE") Raises: @@ -1116,7 +1116,7 @@ class LayerCollection(object): block for this layer (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the word `use` here has a completely different meaning to "use in the graph" - as it perturns to the `inputs`, `outputs`, and `num_uses` arguments.) + as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.) (Default: "VARIABLE_SCOPE") Raises: diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index 42d525c2c21..c8cebc42cb3 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -214,7 +214,7 @@ class NegativeLogProbLoss(LossFunction): Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- product of gradients) with respect to the parameters of the underlying - probability distribtion (whose log-prob defines the loss). Typically this + probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases. @@ -238,7 +238,7 @@ class NegativeLogProbLoss(LossFunction): Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- product of gradients) with respect to the parameters of the underlying - probability distribtion (whose log-prob defines the loss). Typically this + probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases. @@ -262,7 +262,7 @@ class NegativeLogProbLoss(LossFunction): Here the 'Fisher' is the Fisher information matrix (i.e. 
expected outer- product of gradients) with respect to the parameters of the underlying - probability distribtion (whose log-prob defines the loss). Typically this + probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases. diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 03b9da79330..38605259b5f 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -72,7 +72,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): (Higher damping means the update looks more like a standard gradient update - see Tikhonov regularization.) layer_collection: The layer collection object, which holds the fisher - blocks, kronecker factors, and losses associated with the + blocks, Kronecker factors, and losses associated with the graph. The layer_collection cannot be modified after KfacOptimizer's initialization. var_list: Optional list or tuple of variables to train. Defaults to the @@ -99,7 +99,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): placement_strategy: string, Device placement strategy used when creating covariance variables, covariance ops, and inverse ops. (Default: `None`) - **kwargs: Arguments to be passesd to specific placement + **kwargs: Arguments to be passed to specific placement strategy mixin. Check `placement.RoundRobinPlacementMixin` for example. Raises: @@ -120,7 +120,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._estimation_mode = estimation_mode self._colocate_gradients_with_ops = colocate_gradients_with_ops - # The below parameters are required only if damping needs to be adapated. + # The below parameters are required only if damping needs to be adapted. # These parameters can be set by calling # set_damping_adaptation_params() explicitly. self._damping_adaptation_decay = 0.95 @@ -574,7 +574,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): """Wrapper function for `self._compute_qmodel_hyperparams`. Constructs a list of preconditioned gradients and variables. Also creates a - op to asssign the computed q model change to `self._q_model_change`. + op to assign the computed q model change to `self._q_model_change`. Args: grads_and_vars: List of (gradient, variable) pairs. 
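The K-FAC docstrings corrected above describe the Fisher as the expected outer product of per-case gradients, with Tikhonov damping added before inversion. A minimal NumPy sketch of that dense, un-approximated definition may help fix the idea; it is illustrative only (K-FAC itself uses block-diagonal, Kronecker-factored approximations held in the layer collection), and `empirical_fisher` and `damped_inverse` are hypothetical helper names.

```python
import numpy as np

def empirical_fisher(per_example_grads):
  """Average of g g^T over per-example gradient vectors.

  per_example_grads: array of shape [batch_size, num_params].
  Returns a [num_params, num_params] matrix.
  """
  g = np.asarray(per_example_grads)
  return g.T @ g / g.shape[0]

def damped_inverse(fisher, damping):
  # Tikhonov damping as in the KfacOptimizer docstring: adding damping * I
  # before inverting keeps the update closer to a plain gradient step.
  return np.linalg.inv(fisher + damping * np.eye(fisher.shape[0]))

grads = np.random.randn(32, 4)  # toy sizes: 32 cases, 4 parameters
fisher = empirical_fisher(grads)
precond_step = damped_inverse(fisher, damping=1e-2) @ grads.mean(axis=0)
```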
diff --git a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc index 7b28bb5e4db..95c7001371a 100644 --- a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc +++ b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc @@ -164,11 +164,11 @@ class KinesisDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const string& stream, const string& shard, const bool read_indefinitely, const int64 interval) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), stream_(stream), shard_(shard), read_indefinitely_(read_indefinitely), diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py index 1192198ec26..655f038b184 100644 --- a/tensorflow/contrib/layers/python/layers/initializers.py +++ b/tensorflow/contrib/layers/python/layers/initializers.py @@ -111,7 +111,7 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False, if not dtype.is_floating: raise TypeError('Cannot create initializer for non-floating point type.') if mode not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG']: - raise TypeError('Unknow mode %s [FAN_IN, FAN_OUT, FAN_AVG]', mode) + raise TypeError('Unknown mode %s [FAN_IN, FAN_OUT, FAN_AVG]', mode) # pylint: disable=unused-argument def _initializer(shape, dtype=dtype, partition_info=None): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index d3aa3fa92c3..418b0cf3920 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -108,7 +108,6 @@ py_test( size = "small", srcs = ["python/learn/learn_io/data_feeder_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/python:client_testlib", @@ -164,7 +163,6 @@ tf_py_test( "//tensorflow/python:variables", "//tensorflow/python/estimator:estimator_py", ], - tags = ["no_windows"], # TODO: needs investigation on Windows ) py_test( @@ -591,7 +589,6 @@ py_test( size = "small", srcs = ["python/learn/learn_io/io_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/contrib/learn/python/learn/datasets", diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py index 9872c6f97c8..8ebe45d8510 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py @@ -158,7 +158,7 @@ class SDCAOptimizer(object): # exactly 2 (i.e., its shape should be [batch_size, column.dim]). check_rank_op = control_flow_ops.Assert( math_ops.less_equal(array_ops.rank(transformed_tensor), 2), - ['transformed_tensor shouls have rank at most 2.']) + ['transformed_tensor should have rank at most 2.']) # Reshape to [batch_size, dense_column_dimension]. with ops.control_dependencies([check_rank_op]): transformed_tensor = array_ops.reshape(transformed_tensor, [ @@ -172,7 +172,7 @@ class SDCAOptimizer(object): elif isinstance(column, layers.feature_column._BucketizedColumn): # pylint: disable=protected-access # A bucketized column corresponds to a sparse feature in SDCA. 
The # bucketized feature is "sparsified" for SDCA by converting it to a - # SparseFeatureColumn respresenting the one-hot encoding of the + # SparseFeatureColumn representing the one-hot encoding of the # bucketized feature. # # TODO(sibyl-vie3Poto): Explore whether it is more efficient to translate a @@ -220,7 +220,7 @@ class SDCAOptimizer(object): # occur multiple times for a single example. projected_ids = projection_length * example_ids + flat_ids - # Remove any redudant ids. + # Remove any redundant ids. ids, idx = array_ops.unique(projected_ids) # Keep only one example id per duplicated ids. example_ids_filtered = math_ops.unsorted_segment_min( diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 81844756bc7..ab694d768f9 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -227,6 +227,7 @@ def generated_test_models(): "constant", "control_dep", "conv", + "conv_with_shared_weights", "depthwiseconv", "div", "equal", diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 5bc20106d31..c920f6a508b 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -452,13 +452,15 @@ typedef struct _TfLiteDelegate { // Copy the data from delegate buffer handle to raw memory. // This can be null if the delegate doesn't use its own buffer. - TfLiteStatus (*CopyFromBufferHandle)(TfLiteDelegate* delegate, + TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context, + TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, void* data, size_t size); // Copy the data from raw memory to delegate buffer handle. // This can be null if the delegate doesn't use its own buffer. - TfLiteStatus (*CopyToBufferHandle)(TfLiteDelegate* delegate, + TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context, + TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, void* data, size_t size); @@ -466,7 +468,7 @@ typedef struct _TfLiteDelegate { // this doesn't release the underlying resource (e.g. textures). The // resources are either owned by application layer or the delegate. // This can be null if the delegate doesn't use its own buffer. 
- void (*FreeBufferHandle)(TfLiteDelegate* delegate, + void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate, TfLiteBufferHandle* handle); } TfLiteDelegate; diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD index bb518becc58..5a7eb370f6c 100644 --- a/tensorflow/contrib/lite/delegates/eager/BUILD +++ b/tensorflow/contrib/lite/delegates/eager/BUILD @@ -18,18 +18,21 @@ cc_library( "//tensorflow/c:c_api_internal", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + ], + }), ) tf_cc_test( name = "buffer_map_test", size = "small", srcs = ["buffer_map_test.cc"], - tags = [ - "tflite_not_portable", - ], deps = [ ":buffer_map", "//tensorflow/contrib/lite:framework", @@ -55,17 +58,20 @@ cc_library( "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite:util", - "//tensorflow/core:lib", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:lib", + ], + }), ) tf_cc_test( name = "delegate_test", size = "small", srcs = ["delegate_test.cc"], - tags = [ - "tflite_not_portable", - ], deps = [ ":delegate", ":test_util", @@ -80,19 +86,22 @@ cc_library( hdrs = ["delegate_data.h"], deps = [ ":buffer_map", - "//tensorflow/core:core_cpu", - "//tensorflow/core:lib", "//tensorflow/core/common_runtime/eager:context", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + ], + }), ) tf_cc_test( name = "delegate_data_test", size = "small", srcs = ["delegate_data_test.cc"], - tags = [ - "tflite_not_portable", - ], deps = [ ":delegate_data", "//tensorflow/contrib/lite:framework", @@ -109,25 +118,28 @@ cc_library( deps = [ ":delegate_data", ":util", + "@flatbuffers", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite:string", "//tensorflow/contrib/lite/kernels:kernel_util", - "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:execute", "//tensorflow/core/common_runtime/eager:tensor_handle", - "@flatbuffers", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:protos_all_cc", + ], + }), ) tf_cc_test( name = "kernel_test", size = "small", srcs = ["kernel_test.cc"], - tags = [ - "tflite_not_portable", - ], deps = [ ":delegate_data", ":kernel", @@ -159,18 +171,21 @@ cc_library( "//tensorflow/c:c_api_internal", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:lib", + "//tensorflow/core:framework", + ], + }), ) tf_cc_test( name = "util_test", size = "small", srcs = ["util_test.cc"], - tags = [ - "tflite_not_portable", - ], deps = [ ":util", "//tensorflow/contrib/lite:string", diff --git 
a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc index 7d22b454199..8ab768575e8 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate.cc +++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc @@ -55,17 +55,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { return kTfLiteOk; } -TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate, +TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, void* data, size_t size) { - // TODO(nupurgarg): Make BufferMap unique to each interpreter in order to - // support multiple interpreters using a single delegate. BufferMap* buffer_map = - reinterpret_cast(delegate->data_)->GetBufferMap(); + reinterpret_cast(delegate->data_)->GetBufferMap(context); - // TODO(nupurgarg): Use TfLiteContext's ReportError instead of fprinf. if (!buffer_map->HasTensor(buffer_handle)) { - fprintf(stderr, "Invalid tensor index %d.\n", buffer_handle); + context->ReportError(context, "Invalid tensor index %d.", buffer_handle); return kTfLiteError; } @@ -73,7 +71,8 @@ TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate, tensorflow::StringPiece t_data = t.tensor_data(); if (size != t_data.size()) { - fprintf(stderr, "Not enough space to store TensorFlow's aligned buffer.\n"); + context->ReportError( + context, "Not enough space to store TensorFlow's aligned buffer."); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h index 0defca7c323..a07002f4870 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate.h +++ b/tensorflow/contrib/lite/delegates/eager/delegate.h @@ -26,8 +26,8 @@ namespace tflite { // executed by TensorFlow's runtime via Eager. // // The interpreter must be constructed after the EagerDelegate and destructed -// before the EagerDelegate. This delegate can only be used with one -// interpreter. +// before the EagerDelegate. This delegate may be used with multiple +// interpreters, but it is *not* thread-safe. // // Usage: // EagerDelegate delegate; diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.h b/tensorflow/contrib/lite/delegates/eager/delegate_data.h index 8a0e8ba8bf2..772d26f44e8 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_data.h +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.h @@ -32,14 +32,18 @@ class DelegateData { // The EagerContext that is required for execution of Eager Ops. tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); } - // Map from TF Lite tensor index to TensorFlow tensor. - BufferMap* GetBufferMap() { return &buffer_map_; } + // Map from TF Lite tensor index to TensorFlow tensor for a given context. + BufferMap* GetBufferMap(const TfLiteContext* context) { + return &buffer_map_[context]; + } private: explicit DelegateData(tensorflow::EagerContext* eager_context); std::unique_ptr eager_context_; - BufferMap buffer_map_; + // TODO(b/112439500): Clean up stale BufferMap instances after adding the + // necessary cleanup hook from a TfLiteContext to a TfLiteDelegate. 
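
For orientation, a minimal sketch (a hypothetical stub, not part of this patch) of how a delegate fills in the buffer-handle callbacks under the new context-aware signatures; MakeStubDelegate and the zero-fill body are invented for illustration, and errors are routed through context->ReportError as the Eager delegate above now does.

#include <cstddef>
#include <cstring>

#include "tensorflow/contrib/lite/context.h"

// Hypothetical stub: ignores real buffer contents and zero-fills the
// destination, but demonstrates the new context-aware callback signatures.
TfLiteDelegate MakeStubDelegate() {
  TfLiteDelegate stub = {};
  stub.CopyFromBufferHandle = [](TfLiteContext* context,
                                 TfLiteDelegate* delegate,
                                 TfLiteBufferHandle buffer_handle, void* data,
                                 size_t size) -> TfLiteStatus {
    if (buffer_handle == kTfLiteNullBufferHandle) {
      context->ReportError(context, "Invalid buffer handle %d.", buffer_handle);
      return kTfLiteError;
    }
    std::memset(data, 0, size);  // A real delegate copies its own storage here.
    return kTfLiteOk;
  };
  stub.CopyToBufferHandle = [](TfLiteContext* context, TfLiteDelegate* delegate,
                               TfLiteBufferHandle buffer_handle, void* data,
                               size_t size) -> TfLiteStatus { return kTfLiteOk; };
  stub.FreeBufferHandle = [](TfLiteContext* context, TfLiteDelegate* delegate,
                             TfLiteBufferHandle* handle) {
    *handle = kTfLiteNullBufferHandle;
  };
  return stub;
}
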
+ std::unordered_map buffer_map_; }; } // namespace eager diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc index 30251b8f82c..b3a0ffcec1d 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { @@ -29,8 +30,12 @@ TEST(DelegateDataTest, Basic) { // binary. EXPECT_TRUE(DelegateData::Create(&data).ok()); + TfLiteContext dummy_context1 = {}; + TfLiteContext dummy_context2 = {}; EXPECT_NE(data->GetEagerContext(), nullptr); - EXPECT_NE(data->GetBufferMap(), nullptr); + EXPECT_NE(data->GetBufferMap(&dummy_context1), nullptr); + EXPECT_NE(data->GetBufferMap(&dummy_context1), + data->GetBufferMap(&dummy_context2)); } } // namespace diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc index 88fb34044ec..511a239363e 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc @@ -25,8 +25,6 @@ namespace { using ::testing::ContainsRegex; using ::testing::ElementsAre; -// TODO(nupurgarg): Add a test with multiple interpreters for one delegate. - class DelegateTest : public testing::EagerModelTest { public: DelegateTest() { @@ -139,6 +137,56 @@ TEST_F(DelegateTest, OnlyTFLite) { ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f)); } +TEST_F(DelegateTest, MultipleInterpretersSameDelegate) { + // Build a graph, configure the delegate and set inputs. + { + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfOp(testing::kMul, {6, 7}, {8}); + ConfigureDelegate(); + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + } + + // Create a new interpreter, inject into the test framework and build + // a different graph using the *same* delegate. + std::unique_ptr interpreter(new Interpreter(&error_reporter_)); + interpreter_.swap(interpreter); + { + AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kAdd, {1, 2}, {3}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfLiteMulOp({4, 5}, {6}); + AddTfOp(testing::kUnpack, {6}, {7, 8}); + AddTfOp(testing::kAdd, {7, 8}, {9}); + ConfigureDelegate(); + SetShape(0, {2, 2, 2, 1}); + SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); + } + + // Swap back in the first interpreter and validate inference. + interpreter_.swap(interpreter); + { + ASSERT_TRUE(Invoke()); + EXPECT_THAT(GetShape(8), ElementsAre(2, 1)); + EXPECT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); + } + + // Swap in the second interpreter and validate inference. 
+ interpreter_.swap(interpreter); + { + ASSERT_TRUE(Invoke()); + EXPECT_THAT(GetShape(9), ElementsAre(1)); + EXPECT_THAT(GetValues(9), ElementsAre(10.0f)); + } +} + } // namespace } // namespace eager } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc index 1bd17a3bcae..1082b787259 100644 --- a/tensorflow/contrib/lite/delegates/eager/kernel.cc +++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc @@ -150,8 +150,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { op_data->eager_context = reinterpret_cast(params->delegate->data_) ->GetEagerContext(); - op_data->buffer_map = - reinterpret_cast(params->delegate->data_)->GetBufferMap(); + op_data->buffer_map = reinterpret_cast(params->delegate->data_) + ->GetBufferMap(context); CHECK(params->output_tensors); for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) { diff --git a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc index b7bfbb34e49..66f22266266 100644 --- a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc @@ -55,12 +55,14 @@ class KernelTest : public testing::EagerModelTest { delegate_.data_ = delegate_data_.get(); delegate_.FreeBufferHandle = nullptr; delegate_.Prepare = prepare_function; - delegate_.CopyFromBufferHandle = [](TfLiteDelegate* delegate, + delegate_.CopyFromBufferHandle = [](TfLiteContext* context, + TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, void* data, size_t size) { auto* delegate_data = reinterpret_cast(delegate->data_); - tensorflow::StringPiece values = - delegate_data->GetBufferMap()->GetTensor(buffer_handle).tensor_data(); + tensorflow::StringPiece values = delegate_data->GetBufferMap(context) + ->GetTensor(buffer_handle) + .tensor_data(); memcpy(data, values.data(), values.size()); return kTfLiteOk; }; diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/error_reporter.cc index 03fcd5409ce..646913c0262 100644 --- a/tensorflow/contrib/lite/error_reporter.cc +++ b/tensorflow/contrib/lite/error_reporter.cc @@ -16,6 +16,10 @@ limitations under the License. #include #include +#ifdef __ANDROID__ +#include +#endif + namespace tflite { ErrorReporter::~ErrorReporter() {} @@ -39,6 +43,15 @@ int ErrorReporter::ReportError(void*, const char* format, ...) { } int StderrReporter::Report(const char* format, va_list args) { +#ifdef __ANDROID__ + // On Android stderr is not captured for applications, only for code run from + // the shell. 
Rather than assume all users will set up a custom error + // reporter, let's output to logcat here + va_list args_for_log; + va_copy(args_for_log, args); + __android_log_vprint(ANDROID_LOG_ERROR, "tflite", format, args_for_log); + va_end(args_for_log); +#endif const int result = vfprintf(stderr, format, args); fputc('\n', stderr); return result; diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm index 30fee64a6f6..734b15e0a10 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -26,7 +26,7 @@ #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" +#include "tensorflow/contrib/lite/op_resolver.h" #define LOG(x) std::cerr diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile index cd8c39043f6..8084307ac79 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/Podfile +++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile @@ -2,4 +2,4 @@ platform :ios, '8.0' inhibit_all_warnings! target 'tflite_camera_example' - pod 'TensorFlowLite', '0.1.7' + pod 'TensorFlowLite', '1.10.0' diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile index c885398f444..eea7ecb7596 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/Podfile +++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile @@ -2,4 +2,4 @@ platform :ios, '8.0' inhibit_all_warnings! target 'tflite_simple_example' - pod 'TensorFlowLite', '0.1.7' + pod 'TensorFlowLite', '1.10.0' diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm index 0ab7aa25d0b..650c73f7322 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm @@ -25,7 +25,7 @@ #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" +#include "tensorflow/contrib/lite/op_resolver.h" #include "ios_image_load.h" diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 7a680f5c640..362e5887257 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -157,7 +157,7 @@ Interpreter::~Interpreter() { TfLiteTensor* tensor = &context_.tensors[i]; if (tensor->buffer_handle != kTfLiteNullBufferHandle && tensor->delegate->FreeBufferHandle != nullptr) { - tensor->delegate->FreeBufferHandle(tensor->delegate, + tensor->delegate->FreeBufferHandle(&context_, tensor->delegate, &tensor->buffer_handle); } TfLiteTensorFree(tensor); @@ -988,7 +988,7 @@ TfLiteStatus Interpreter::SetBufferHandle(int tensor_index, tensor->delegate = delegate; if (tensor->buffer_handle != kTfLiteNullBufferHandle) { TF_LITE_ENSURE(&context_, tensor->delegate->FreeBufferHandle != nullptr); - tensor->delegate->FreeBufferHandle(tensor->delegate, + tensor->delegate->FreeBufferHandle(&context_, tensor->delegate, &tensor->buffer_handle); } tensor->buffer_handle = buffer_handle; diff --git 
a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 159ff7bc20a..a27df4b964c 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -350,7 +350,7 @@ class Interpreter { // This can be null if the delegate doesn't use its own buffer. TF_LITE_ENSURE(&context_, tensor->delegate->CopyFromBufferHandle != nullptr); - tensor->delegate->CopyFromBufferHandle(tensor->delegate, + tensor->delegate->CopyFromBufferHandle(&context_, tensor->delegate, tensor->buffer_handle, tensor->data.raw, tensor->bytes); tensor->data_is_stale = false; diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 2bf598bad71..f00697826c0 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -1080,21 +1080,22 @@ class TestDelegate : public ::testing::Test { return kTfLiteOk; }; delegate_.CopyToBufferHandle = - [](TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, - void* data, size_t size) -> TfLiteStatus { + [](TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, void* data, + size_t size) -> TfLiteStatus { // TODO(ycling): Implement tests to test buffer copying logic. return kTfLiteOk; }; delegate_.CopyFromBufferHandle = - [](TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, - void* data, size_t size) -> TfLiteStatus { + [](TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, void* data, + size_t size) -> TfLiteStatus { // TODO(ycling): Implement tests to test buffer copying logic. return kTfLiteOk; }; - delegate_.FreeBufferHandle = [](TfLiteDelegate* delegate, - TfLiteBufferHandle* handle) { - *handle = kTfLiteNullBufferHandle; - }; + delegate_.FreeBufferHandle = + [](TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; }; // Store type-punned data SimpleDelegate structure. delegate_.data_ = reinterpret_cast(this); } diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 817266a4714..d6d62580e2d 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -40,6 +40,11 @@ struct OpData { int diff_min = 0; }; +struct LogSoftmaxOpData : public OpData { + int32_t reverse_scaling_divisor = 0; + int32_t reverse_scaling_right_shift = 0; +}; + void* Init(TfLiteContext* context, const char* buffer, size_t length) { // This is a builtin op, so we don't use the contents in 'buffer', if any. 
// Instead, we allocate a new object to carry information from Prepare() to @@ -47,10 +52,19 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { return new OpData; } +void* LogSoftmaxInit(TfLiteContext* context, const char* buffer, + size_t length) { + return new LogSoftmaxOpData; +} + void Free(TfLiteContext* context, void* buffer) { delete reinterpret_cast(buffer); } +void LogSoftmaxFree(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -205,6 +219,34 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayCopy(input->dims)); } +TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { + LogSoftmaxOpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255); + TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256); + + static const double kBeta = 1.0; + static const int kScaledDiffIntegerBits = 5; + tflite::PreprocessLogSoftmaxScalingExp( + kBeta, input->params.scale, kScaledDiffIntegerBits, + &data->input_multiplier, &data->input_left_shift, + &data->reverse_scaling_divisor, &data->reverse_scaling_right_shift); + data->reverse_scaling_right_shift *= -1; + data->diff_min = -1.0 * tflite::CalculateInputRadius( + kScaledDiffIntegerBits, data->input_left_shift); + } + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -509,6 +551,8 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { + const LogSoftmaxOpData* data = + reinterpret_cast(node->user_data); const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { @@ -517,6 +561,14 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { GetTensorData(input), GetTensorShape(input), GetTensorData(output), GetTensorShape(output)); return kTfLiteOk; + case kTfLiteUInt8: + optimized_ops::LogSoftmax( + GetTensorData(input), GetTensorShape(input), + data->input_multiplier, data->input_left_shift, + data->reverse_scaling_divisor, data->reverse_scaling_right_shift, + data->diff_min, GetTensorData(output), + GetTensorShape(output)); + return kTfLiteOk; default: context->ReportError(context, "Only float32 supported currently., got %d", input->type); @@ -590,9 +642,9 @@ TfLiteRegistration* Register_SOFTMAX() { } TfLiteRegistration* Register_LOG_SOFTMAX() { - static TfLiteRegistration r = {activations::Init, activations::Free, - activations::GenericPrepare, - activations::LogSoftmaxEval}; + static TfLiteRegistration r = { + activations::LogSoftmaxInit, activations::LogSoftmaxFree, + activations::LogSoftmaxPrepare, activations::LogSoftmaxEval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc 
b/tensorflow/contrib/lite/kernels/activations_test.cc index 083cdf78d76..e577e3a762b 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -471,6 +471,28 @@ TEST(FloatActivationsOpTest, LogSoftmax) { }))); } +TEST(QuantizedActivationsOpTest, LogSoftmax) { + const float kLogSoftmaxQuantizedTolerance = 16 / 256.0; + QuantizedActivationsOpModel m( + BuiltinOperator_LOG_SOFTMAX, + /*input=*/{TensorType_UINT8, {2, 4}, -10, 10}, + /*output=*/{TensorType_UINT8, {}, 0, 0, 16. / 256, 255}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + -4.14297, -10.14297, -2.14297, -.142971, // + -7.00104, -12.00104, -.00104087, -9.00104, // + }, + kLogSoftmaxQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({189, 93, 221, 253, 142, 63, 255, 111})); +} + class PReluOpModel : public SingleOpModel { public: PReluOpModel(const TensorData& input, const TensorData& alpha) { diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 04c0263b789..50fe5c2e042 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -334,18 +334,31 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, auto filter_offset = -filter->params.zero_point; auto output_offset = output->params.zero_point; - switch (kernel_type) { + KernelType effective_kernel_type; + if ((kernel_type == kMultithreadOptimized || + kernel_type == kCblasOptimized) && + (params->dilation_width_factor != 1 || + params->dilation_height_factor != 1)) { + // kMultithreadOptimized and kCblasOptimized do not support dilation. + // Therefore, fallback to optimized. 
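
As a quick sanity check on the quantized LogSoftmax test above: with the required output quantization (scale 16/256, zero point 255), a quantized value q dequantizes to (q - 255) * 16/256, so q = 189 gives -4.125 and q = 253 gives -0.125, matching the float references -4.14297 and -0.142971 within the stated tolerance of 16/256 = 0.0625.
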
+ effective_kernel_type = kGenericOptimized; + } else { + effective_kernel_type = kernel_type; + } + + switch (effective_kernel_type) { case kReference: reference_ops::Conv( GetTensorData(input), GetTensorDims(input), input_offset, GetTensorData(filter), GetTensorDims(filter), filter_offset, GetTensorData(bias), GetTensorDims(bias), - params->stride_width, params->stride_height, data->padding.width, - data->padding.height, output_offset, data->output_multiplier, - data->output_shift, data->output_activation_min, - data->output_activation_max, GetTensorData(output), - GetTensorDims(output), GetTensorData(im2col), - GetTensorDims(im2col), gemm_context); + params->stride_width, params->stride_height, + params->dilation_width_factor, params->dilation_height_factor, + data->padding.width, data->padding.height, output_offset, + data->output_multiplier, data->output_shift, + data->output_activation_min, data->output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col), gemm_context); break; case kGenericOptimized: case kMultithreadOptimized: @@ -355,12 +368,13 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, GetTensorData(input), GetTensorDims(input), input_offset, GetTensorData(filter), GetTensorDims(filter), filter_offset, GetTensorData(bias), GetTensorDims(bias), - params->stride_width, params->stride_height, data->padding.width, - data->padding.height, output_offset, data->output_multiplier, - data->output_shift, data->output_activation_min, - data->output_activation_max, GetTensorData(output), - GetTensorDims(output), GetTensorData(im2col), - GetTensorDims(im2col), gemm_context); + params->stride_width, params->stride_height, + params->dilation_width_factor, params->dilation_height_factor, + data->padding.width, data->padding.height, output_offset, + data->output_multiplier, data->output_shift, + data->output_activation_min, data->output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col), gemm_context); break; } } @@ -374,10 +388,10 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, CalculateActivationRange(params->activation, &output_activation_min, &output_activation_max); KernelType effective_kernel_type; - if (((kernel_type == kMultithreadOptimized) || - (kernel_type == kCblasOptimized)) && - ((params->dilation_width_factor != 1) || - (params->dilation_height_factor != 1))) { + if ((kernel_type == kMultithreadOptimized || + kernel_type == kCblasOptimized) && + (params->dilation_width_factor != 1 || + params->dilation_height_factor != 1)) { // kMultithreadOptimized and kCblasOptimized do not support dilation. // Therefore, fallback to optimized. 
effective_kernel_type = kGenericOptimized; diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc index 24633c2fd7c..98152043c99 100644 --- a/tensorflow/contrib/lite/kernels/conv_test.cc +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -370,6 +370,65 @@ TEST_P(ConvolutionOpTest, HandCalculatedValidFloat32) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({312, 357})); } +TEST_P(ConvolutionOpTest, SimpleTestFloatWithDilation) { + const int depth = 1; + const int image_width = 9; + const int image_height = 9; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const int dilation_width_factor = 3; + const int dilation_height_factor = 3; + const Padding padding = Padding_VALID; + ConvolutionOpModel m( + GetRegistration(), + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding, + ActivationFunctionType_NONE, dilation_width_factor, + dilation_height_factor); + + // The image matrix is: + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // clang-format off + m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0}); + // clang-format on + // The filter matrix is: + // | 1 | 2 | 3 | + // | 4 | 5 | 6 | + // | 7 | 8 | 9 | + m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9}); + // No bias for this test. + m.SetBias({0}); + m.Invoke(); + + // Since the dilation rate is 3 this will reduce the size of the output from + // 10x10 to 3x3 of all 5s. 
Specifically: + // | 5 | 5 | 5 | + // | 5 | 5 | 5 | + // | 5 | 5 | 5 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5})); +} + class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { public: using BaseConvolutionOpModel::BaseConvolutionOpModel; @@ -500,6 +559,71 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { })); } +TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithDilation) { + const int depth = 1; + const int image_width = 9; + const int image_height = 9; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const int dilation_width_factor = 3; + const int dilation_height_factor = 3; + const Padding padding = Padding_VALID; + QuantizedConvolutionOpModel m( + GetRegistration(), + {TensorType_UINT8, + {image_batch_count, image_height, image_width, depth}, + 0, + 255}, + {TensorType_UINT8, + {depth, filter_size, filter_size, filter_count}, + 0, + 255}, + {TensorType_UINT8, {}, 0, 255}, stride_width, stride_height, padding, + ActivationFunctionType_NONE, dilation_width_factor, + dilation_height_factor); + + // The image matrix is: + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // clang-format off + m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0}); + // clang-format on + // The filter matrix is: + // | 1 | 2 | 3 | + // | 4 | 5 | 6 | + // | 7 | 8 | 9 | + m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9}); + // No bias for this test. + m.SetBias({0}); + m.Invoke(); + + // Since the dilation rate is 3 this will reduce the size of the output from + // 10x10 to 3x3 of all 5s. 
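
Why both dilation tests produce a 3x3 block of 5s: with a 3x3 filter and dilation factor 3, the effective receptive field is 3*(3-1)+1 = 7, so a VALID convolution over the 9x9 input yields (9-7)/1+1 = 3 outputs per spatial dimension. Because the nine filter taps land three pixels apart, only the centre tap ever falls inside the 3x3 block of ones, so every output equals the centre filter weight, 5.
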
Specifically: + // | 5 | 5 | 5 | + // | 5 | 5 | 5 | + // | 5 | 5 | 5 | + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5})); +} + INSTANTIATE_TEST_CASE_P( ConvolutionOpTest, ConvolutionOpTest, ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 87155e4ba42..a97db6c6b25 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -539,7 +539,10 @@ cc_test( cc_test( name = "depthwiseconv_quantized_test", srcs = ["depthwiseconv_quantized_test.cc"], - tags = ["no_oss"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":optimized_base", ":reference_base", diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h index d5503073a7c..7f0676be274 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -30,11 +30,6 @@ namespace optimized_ops { using reference_ops::Relu1; using reference_ops::Relu6; -inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { - return RuntimeShape( - {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}); -} - template void L2Normalization(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { @@ -294,6 +289,37 @@ void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, output_data); } +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul4DSlow( + input1_data, input1_dims, input1_offset, input2_data, input2_dims, + input2_offset, output_offset, output_multiplier, + // This legacy version switches the sign of the output shift. + kReverseShift * output_shift, + // (Break to highlight preceding line.) 
+ output_activation_min, output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, + input2_dims, input2_offset, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_data, output_dims); +} + inline void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int kwidth, int kheight, diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index b8707897723..2d172315da5 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -47,6 +47,7 @@ using reference_ops::BroadcastGreater; using reference_ops::BroadcastGreaterEqual; using reference_ops::BroadcastLess; using reference_ops::BroadcastLessEqual; +using reference_ops::BroadcastMul4DSlow; using reference_ops::BroadcastSub4DSlow; using reference_ops::Concatenation; using reference_ops::DepthConcatenation; @@ -75,6 +76,11 @@ using reference_ops::Transpose; // Used mainly to convert from old-style shifts (right) to new-style (left). static constexpr int kReverseShift = -1; +inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { + return RuntimeShape( + {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}); +} + // Make a local VectorMap typedef allowing to map a float array // as a Eigen vector expression. 
The std::conditional here is to // construct the suitable Eigen type for the constness of the @@ -1978,12 +1984,12 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, const int32* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims, uint8* im2col_data, - const Dims<4>& im2col_dims, + int stride_width, int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, + uint8* im2col_data, const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) { gemmlowp::ScopedProfilingLabel label("Conv/8bit"); @@ -1995,9 +2001,22 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, const Dims<4>* gemm_input_dims = nullptr; const int filter_width = ArraySize(filter_dims, 1); const int filter_height = ArraySize(filter_dims, 2); + const bool need_dilated_im2col = + dilation_width_factor != 1 || dilation_height_factor != 1; const bool need_im2col = stride_width != 1 || stride_height != 1 || filter_width != 1 || filter_height != 1; - if (need_im2col) { + if (need_dilated_im2col) { + TFLITE_DCHECK(im2col_data); + const int input_zero_point = -input_offset; + TFLITE_DCHECK_GE(input_zero_point, 0); + TFLITE_DCHECK_LE(input_zero_point, 255); + DilatedIm2col(input_data, input_dims, filter_dims, stride_width, + stride_height, dilation_width_factor, dilation_height_factor, + pad_width, pad_height, output_dims, input_zero_point, + im2col_data); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else if (need_im2col) { TFLITE_DCHECK(im2col_data); const int input_zero_point = -input_offset; TFLITE_DCHECK_GE(input_zero_point, 0); @@ -2053,6 +2072,24 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, input_offset, output_pipeline); } +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1, + pad_width, pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + // legacy, for compatibility with old checked-in code template inline void Conv(const uint8* input_data, const Dims<4>& input_dims, @@ -2904,66 +2941,128 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, output_dims); } -inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, - int32 input1_offset, const uint8* input2_data, - const Dims<4>& 
input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; - const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - const int32 unclamped_result = - output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( - input1_val * input2_val, output_multiplier, - kReverseShift * output_shift); - const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, unclamped_result)); - output_data[Offset(output_dims, c, x, y, b)] = - static_cast(clamped_output); - } - } - } +// Element-wise mul that can often be used for inner loop of broadcast Mul as +// well as the non-broadcast Mul. +inline void MulElementwise(int size, const ArithmeticParams& params, + const uint8* input1_data, const uint8* input2_data, + uint8* output_data) { + for (int i = 0; i < size; ++i) { + const int32 input1_val = params.input1_offset + input1_data[i]; + const int32 input2_val = params.input2_offset + input2_data[i]; + const int32 unclamped_result = + params.output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val, + params.output_multiplier, + params.output_shift); + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, unclamped_result)); + output_data[i] = static_cast(clamped_output); } } -// legacy, for compatibility with old checked-in code -template -inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, - int32 input1_offset, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, - input2_dims, input2_offset, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_data, output_dims); +// Broadcast mul that can often be used for inner loop of broadcast Mul. 
+inline void MulSimpleBroadcast(int size, const ArithmeticParams& params, + const uint8 broadcast_value, + const uint8* input2_data, uint8* output_data) { + const int32 input1_val = params.input1_offset + broadcast_value; + + for (int i = 0; i < size; ++i) { + const int32 input2_val = params.input2_offset + input2_data[i]; + const int32 unclamped_result = + params.output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val, + params.output_multiplier, + params.output_shift); + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, unclamped_result)); + output_data[i] = static_cast(clamped_output); + } +} + +inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const uint8* input1_data, + const RuntimeShape& input2_shape, const uint8* input2_data, + const RuntimeShape& output_shape, uint8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + gemmlowp::ScopedProfilingLabel label("Mul/8bit"); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + + MulElementwise(flat_size, params, input1_data, input2_data, output_data); +} + +inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, + const RuntimeShape& unswitched_input1_shape, + const uint8* unswitched_input1_data, + const RuntimeShape& unswitched_input2_shape, + const uint8* unswitched_input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastMulFivefold/8bit"); + + ArithmeticParams switched_params = unswitched_params; + switched_params.input1_offset = unswitched_params.input2_offset; + switched_params.input2_offset = unswitched_params.input1_offset; + + const bool use_unswitched = + unswitched_params.broadcast_category == + tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; + + const ArithmeticParams& params = + use_unswitched ? unswitched_params : switched_params; + const uint8* input1_data = + use_unswitched ? unswitched_input1_data : unswitched_input2_data; + const uint8* input2_data = + use_unswitched ? unswitched_input2_data : unswitched_input1_data; + + // Fivefold nested loops. The second input resets its position for each + // iteration of the second loop. The first input resets its position at the + // beginning of the fourth loop. The innermost loop is an elementwise Mul of + // sections of the arrays. 
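
To make the fixed-point arithmetic in MulElementwise above easier to follow, here is a floating-point model of a single element. It is illustrative only: QuantizedMulModel and output_scale_ratio are invented names, and the kernel itself uses MultiplyByQuantizedMultiplierSmallerThanOneExp with the (output_multiplier, output_shift) pair, which is assumed here to encode the usual ratio input1_scale * input2_scale / output_scale.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Floating-point model of one element of MulElementwise. Illustrative only:
// the kernel uses the fixed-point (output_multiplier, output_shift) pair via
// MultiplyByQuantizedMultiplierSmallerThanOneExp instead of a double.
inline uint8_t QuantizedMulModel(uint8_t q1, uint8_t q2, int32_t input1_offset,
                                 int32_t input2_offset, int32_t output_offset,
                                 double output_scale_ratio,  // ~ s1 * s2 / s_out
                                 int32_t activation_min, int32_t activation_max) {
  // The offsets are the negated zero points, so these are (q - zero_point).
  const int32_t input1_val = input1_offset + q1;
  const int32_t input2_val = input2_offset + q2;
  const int32_t unclamped =
      output_offset + static_cast<int32_t>(std::lround(
                          input1_val * input2_val * output_scale_ratio));
  return static_cast<uint8_t>(
      std::min(activation_max, std::max(activation_min, unclamped)));
}
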
+ uint8* output_data_ptr = output_data; + const uint8* input1_data_ptr = input1_data; + const uint8* input2_data_reset = input2_data; + int y0 = params.broadcast_shape[0]; + int y1 = params.broadcast_shape[1]; + int y2 = params.broadcast_shape[2]; + int y3 = params.broadcast_shape[3]; + int y4 = params.broadcast_shape[4]; + if (y4 > 1) { + for (int i0 = 0; i0 < y0; ++i0) { + const uint8* input2_data_ptr; + for (int i1 = 0; i1 < y1; ++i1) { + input2_data_ptr = input2_data_reset; + for (int i2 = 0; i2 < y2; ++i2) { + for (int i3 = 0; i3 < y3; ++i3) { + MulElementwise(y4, params, input1_data_ptr, input2_data_ptr, + output_data_ptr); + input2_data_ptr += y4; + output_data_ptr += y4; + } + input1_data_ptr += y4; + } + } + input2_data_reset = input2_data_ptr; + } + } else { + for (int i0 = 0; i0 < y0; ++i0) { + const uint8* input2_data_ptr; + for (int i1 = 0; i1 < y1; ++i1) { + input2_data_ptr = input2_data_reset; + for (int i2 = 0; i2 < y2; ++i2) { + MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr, + output_data_ptr); + input2_data_ptr += y3; + output_data_ptr += y3; + ++input1_data_ptr; + } + } + input2_data_reset = input2_data_ptr; + } + } } // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary @@ -5383,31 +5482,53 @@ void TypedMemset(void* ptr, T value, size_t num) { } } -template -inline void PadV2(const T* input_data, const Dims<4>& input_dims, - const std::vector& left_paddings, - const std::vector& right_paddings, T* output_data, - const Dims<4>& output_dims, const T pad_value) { +// There are two versions of pad: Pad and PadV2. In PadV2 there is a second +// scalar input that provides the padding value. Therefore pad_value_ptr can be +// equivalent to a simple input1_data. For Pad, it should point to a zero +// value. +// +// Note that two typenames are required, so that T=P=int32 is considered a +// specialization distinct from P=int32. +template +inline void PadImpl(const tflite::PadParams& op_params, + const RuntimeShape& input_shape, const T* input_data, + const P* pad_value_ptr, const RuntimeShape& output_shape, + T* output_data) { gemmlowp::ScopedProfilingLabel label("Pad"); - TFLITE_DCHECK_EQ(left_paddings.size(), 4); - TFLITE_DCHECK_EQ(right_paddings.size(), 4); + RuntimeShape ext_input_shape = RuntimeShape::ExtendedShape(4, input_shape); + RuntimeShape ext_output_shape = RuntimeShape::ExtendedShape(4, output_shape); + TFLITE_DCHECK_LE(op_params.left_padding_count, 4); + TFLITE_DCHECK_LE(op_params.right_padding_count, 4); - const int output_batch = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int output_depth = ArraySize(output_dims, 0); + // Runtime calls are currently fixed at 4 dimensions. Copy inputs so + // we can pad them to 4 dims (yes, we are "padding the padding"). 
+ std::vector left_padding_copy(4, 0); + for (int i = 0; i < op_params.left_padding_count; ++i) { + left_padding_copy[i] = op_params.left_padding[i]; + } + std::vector right_padding_copy(4, 0); + for (int i = 0; i < op_params.right_padding_count; ++i) { + right_padding_copy[i] = op_params.right_padding[i]; + } - const int left_b_padding = left_paddings[3]; - const int left_h_padding = left_paddings[2]; - const int left_w_padding = left_paddings[1]; - const int left_d_padding = left_paddings[0]; + const int output_batch = ext_output_shape.Dims(0); + const int output_height = ext_output_shape.Dims(1); + const int output_width = ext_output_shape.Dims(2); + const int output_depth = ext_output_shape.Dims(3); - const int right_b_padding = right_paddings[3]; - const int right_h_padding = right_paddings[2]; - const int right_w_padding = right_paddings[1]; - const int right_d_padding = right_paddings[0]; + const int left_b_padding = left_padding_copy[0]; + const int left_h_padding = left_padding_copy[1]; + const int left_w_padding = left_padding_copy[2]; + const int left_d_padding = left_padding_copy[3]; - const int input_depth = ArraySize(input_dims, 0); + const int right_b_padding = right_padding_copy[0]; + const int right_h_padding = right_padding_copy[1]; + const int right_w_padding = right_padding_copy[2]; + const int right_d_padding = right_padding_copy[3]; + + const int input_depth = ext_input_shape.Dims(3); + // const T pad_value = ExtractFloatOrInt(op_params.pad_value); + const T pad_value = *pad_value_ptr; if (left_b_padding != 0) { TypedMemset( @@ -5417,61 +5538,113 @@ inline void PadV2(const T* input_data, const Dims<4>& input_dims, for (int out_b = left_b_padding; out_b < output_batch - right_b_padding; ++out_b) { if (left_h_padding != 0) { - TypedMemset(output_data + Offset(output_dims, 0, 0, 0, out_b), + TypedMemset(output_data + Offset(ext_output_shape, out_b, 0, 0, 0), pad_value, left_h_padding * output_width * output_depth); } for (int out_h = left_h_padding; out_h < output_height - right_h_padding; ++out_h) { if (left_w_padding != 0) { - TypedMemset(output_data + Offset(output_dims, 0, 0, out_h, out_b), - pad_value, left_w_padding * output_depth); + TypedMemset( + output_data + Offset(ext_output_shape, out_b, out_h, 0, 0), + pad_value, left_w_padding * output_depth); } for (int out_w = left_w_padding; out_w < output_width - right_w_padding; ++out_w) { if (left_d_padding != 0) { TypedMemset( - output_data + Offset(output_dims, 0, out_w, out_h, out_b), + output_data + Offset(ext_output_shape, out_b, out_h, out_w, 0), pad_value, left_d_padding); } T* out = output_data + - Offset(output_dims, left_d_padding, out_w, out_h, out_b); - const T* in = - input_data + Offset(input_dims, 0, out_w - left_w_padding, - out_h - left_h_padding, out_b - left_b_padding); + Offset(ext_output_shape, out_b, out_h, out_w, left_d_padding); + const T* in = input_data + + Offset(ext_input_shape, out_b - left_b_padding, + out_h - left_h_padding, out_w - left_w_padding, 0); memcpy(out, in, input_depth * sizeof(T)); if (right_d_padding != 0) { TypedMemset( - output_data + Offset(output_dims, output_depth - right_d_padding, - out_w, out_h, out_b), + output_data + Offset(ext_output_shape, out_b, out_h, out_w, + output_depth - right_d_padding), pad_value, right_d_padding); } } if (right_w_padding != 0) { - TypedMemset( - output_data + Offset(output_dims, 0, output_width - right_w_padding, - out_h, out_b), - pad_value, right_w_padding * output_depth); + TypedMemset(output_data + Offset(ext_output_shape, out_b, 
out_h, + output_width - right_w_padding, 0), + pad_value, right_w_padding * output_depth); } } if (right_h_padding != 0) { TypedMemset( - output_data + - Offset(output_dims, 0, 0, output_height - right_h_padding, out_b), + output_data + Offset(ext_output_shape, out_b, + output_height - right_h_padding, 0, 0), pad_value, right_h_padding * output_width * output_depth); } } if (right_b_padding != 0) { TypedMemset( output_data + - Offset(output_dims, 0, 0, 0, output_batch - right_b_padding), + Offset(ext_output_shape, output_batch - right_b_padding, 0, 0, 0), pad_value, right_b_padding * output_height * output_width * output_depth); } } -// Legacy Pad() method that casts an int32_t to T before padding. +template +inline void Pad(const tflite::PadParams& op_params, + const RuntimeShape& input_shape, const T* input_data, + const P* pad_value_ptr, const RuntimeShape& output_shape, + T* output_data) { + PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape, + output_data); +} + +// The second (pad-value) input can be int32 when, say, the first is uint8. +template +inline void Pad(const tflite::PadParams& op_params, + const RuntimeShape& input_shape, const T* input_data, + const int32* pad_value_ptr, const RuntimeShape& output_shape, + T* output_data) { + const T converted_pad_value = static_cast(*pad_value_ptr); + PadImpl(op_params, input_shape, input_data, &converted_pad_value, + output_shape, output_data); +} + +// This version avoids conflicting template matching. +template <> +inline void Pad(const tflite::PadParams& op_params, + const RuntimeShape& input_shape, const int32* input_data, + const int32* pad_value_ptr, const RuntimeShape& output_shape, + int32* output_data) { + PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape, + output_data); +} + +// Legacy signature, function covered both Pad and PadV2. +template +inline void PadV2(const T* input_data, const Dims<4>& input_dims, + const std::vector& left_paddings, + const std::vector& right_paddings, T* output_data, + const Dims<4>& output_dims, const T pad_value) { + TFLITE_DCHECK_EQ(left_paddings.size(), 4); + TFLITE_DCHECK_EQ(right_paddings.size(), 4); + tflite::PadParams op_params; + op_params.left_padding_count = 4; + op_params.right_padding_count = 4; + for (int i = 0; i < 4; ++i) { + op_params.left_padding[i] = left_paddings[3 - i]; + op_params.right_padding[i] = right_paddings[3 - i]; + } + // SetFloatOrInt(pad_value, &op_params.pad_value); + const T pad_value_copy = pad_value; + + Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy, + DimsToShape(output_dims), output_data); +} + +// Old Pad that calls legacy PadV2. template inline void Pad(const T* input_data, const Dims<4>& input_dims, const std::vector& left_paddings, @@ -5482,34 +5655,45 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, output_dims, converted_pad_value); } +// Old Pad that only padded with 0. 
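
The 3 - i index flip in the legacy PadV2 wrapper above mirrors DimsToShape: Dims<4> stores dimensions innermost-first (depth, width, height, batch), whereas RuntimeShape and PadParams are batch-major (batch, height, width, depth). For example, legacy paddings {depth: 0, width: 2, height: 1, batch: 0} become left_padding = {0, 1, 2, 0}, so the batch padding lands at index 0 and the depth padding at index 3 of the new layout.
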
template inline void Pad(const T* input_data, const Dims<4>& input_dims, const std::vector& left_paddings, const std::vector& right_paddings, T* output_data, const Dims<4>& output_dims) { - Pad(input_data, input_dims, left_paddings, right_paddings, output_data, - output_dims, 0); + const T pad_value = static_cast(0); + PadV2(input_data, input_dims, left_paddings, right_paddings, output_data, + output_dims, pad_value); } template -inline void Slice(const T* input_data, const Dims<4>& input_dims, - const std::vector& begin, const std::vector& size, - T* output_data, const Dims<4>& output_dims) { - // TODO(dkalenichenko): This op only supports 4D tensors. - TFLITE_DCHECK_EQ(begin.size(), 4); - TFLITE_DCHECK_EQ(size.size(), 4); - const int start_b = begin[3]; - const int stop_b = - size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3]; - const int start_h = begin[2]; - const int stop_h = - size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2]; - const int start_w = begin[1]; - const int stop_w = - size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1]; - const int start_d = begin[0]; - const int stop_d = - size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0]; +inline void Slice(const tflite::SliceParams& op_params, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& output_shape, T* output_data) { + gemmlowp::ScopedProfilingLabel label("Slice"); + RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape); + // TODO(dkalenichenko): This op only supports 4D tensors or smaller. + TFLITE_DCHECK_LE(op_params.begin_count, 4); + TFLITE_DCHECK_LE(op_params.size_count, 4); + const int begin_count = op_params.begin_count; + const int size_count = op_params.size_count; + // We front-pad the begin and size vectors. + const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0]; + const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1) + ? ext_shape.Dims(0) - start_b + : start_b + op_params.size[0]; + const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3]; + const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1) + ? ext_shape.Dims(1) - start_h + : start_h + op_params.size[size_count - 3]; + const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2]; + const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1) + ? ext_shape.Dims(2) - start_w + : start_w + op_params.size[size_count - 2]; + const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1]; + const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1) + ? 
ext_shape.Dims(3) - start_d + : start_d + op_params.size[size_count - 1]; T* out_ptr = output_data; for (int in_b = start_b; in_b < stop_b; ++in_b) { @@ -5517,7 +5701,7 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, for (int in_w = start_w; in_w < stop_w; ++in_w) { const int len = stop_d - start_d; memcpy(out_ptr, - input_data + Offset(input_dims, start_d, in_w, in_h, in_b), + input_data + Offset(ext_shape, in_b, in_h, in_w, start_d), len * sizeof(T)); out_ptr += len; } @@ -5525,26 +5709,58 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, } } +template +inline void Slice(const T* input_data, const Dims<4>& input_dims, + const std::vector& begin, const std::vector& size, + T* output_data, const Dims<4>& output_dims) { + tflite::SliceParams op_params; + op_params.begin_count = 4; + op_params.size_count = 4; + for (int i = 0; i < 4; ++i) { + op_params.begin[i] = begin[3 - i]; + op_params.size[i] = size[3 - i]; + } + + Slice(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +template +void Minimum(const RuntimeShape& input1_shape, const T* input1_data, + const T* input2_data, const RuntimeShape& output_shape, + T* output_data) { + gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum"); + auto input1_map = MapAsVector(input1_data, input1_shape); + auto output_map = MapAsVector(output_data, output_shape); + auto min_value = input2_data[0]; + output_map.array() = input1_map.array().min(min_value); +} + +template +void Maximum(const RuntimeShape& input1_shape, const T* input1_data, + const T* input2_data, const RuntimeShape& output_shape, + T* output_data) { + gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum"); + auto input1_map = MapAsVector(input1_data, input1_shape); + auto output_map = MapAsVector(output_data, output_shape); + auto max_value = input2_data[0]; + output_map.array() = input1_map.array().max(max_value); +} + template void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum"); - auto input1_map = MapAsVector(input1_data, input1_dims); - auto output_map = MapAsVector(output_data, output_dims); - auto min_value = input2_data[0]; - output_map.array() = input1_map.array().min(min_value); + Minimum(DimsToShape(input1_dims), input1_data, input2_data, + DimsToShape(output_dims), output_data); } template void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum"); - auto input1_map = MapAsVector(input1_data, input1_dims); - auto output_map = MapAsVector(output_data, output_dims); - auto max_value = input2_data[0]; - output_map.array() = input1_map.array().max(max_value); + Maximum(DimsToShape(input1_dims), input1_data, input2_data, + DimsToShape(output_dims), output_data); } template diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h index bcf5e4e4f65..b862ae38c7b 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h @@ -26,11 +26,6 @@ namespace tflite { namespace reference_ops { -inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { - return RuntimeShape( - {dims.sizes[3], 
dims.sizes[2], dims.sizes[1], dims.sizes[0]}); -} - template void L2Normalization(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { @@ -316,6 +311,37 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, DimsToShape(output_dims), output_data); } +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul4DSlow( + input1_data, input1_dims, input1_offset, input2_data, input2_dims, + input2_offset, output_offset, output_multiplier, + // + kReverseShift * output_shift, + // + output_activation_min, output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, + input2_dims, input2_offset, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_data, output_dims); +} + // legacy, for compatibility with old checked-in code template void AveragePool(const float* input_data, const Dims<4>& input_dims, diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index f4176e474e7..cb254f36cc1 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -105,6 +105,11 @@ namespace reference_ops { // Used mainly to convert from old-style shifts (right) to new-style (left). static constexpr int kReverseShift = -1; +inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { + return RuntimeShape( + {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}); +} + template int CountLeadingZeros(T integer_input) { static_assert(std::is_unsigned::value, @@ -271,12 +276,12 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, const int32* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims, uint8* im2col_data, - const Dims<4>& im2col_dims, + int stride_width, int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, + uint8* im2col_data, const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) { (void)im2col_data; // only used in optimized code. (void)im2col_dims; // only used in optimized code. 
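The hunk below applies the new dilation factors when walking the filter window. As a rough, editor-added illustration (not part of the patch; the values are arbitrary), this standalone sketch prints which input columns a 3-tap filter reads, using the same indexing expression as the updated kernel:

// Editor-added sketch, not part of the patch. Arbitrary example values.
#include <cstdio>

int main() {
  const int in_x_origin = 0;
  const int filter_width = 3;
  const int dilations[] = {1, 2};
  for (int dilation_width_factor : dilations) {
    std::printf("dilation %d:", dilation_width_factor);
    for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
      // Same expression as the updated reference kernel.
      std::printf(" %d", in_x_origin + dilation_width_factor * filter_x);
    }
    std::printf("\n");  // dilation 1 -> 0 1 2; dilation 2 -> 0 2 4
  }
  return 0;
}

The effective footprint grows to (filter_width - 1) * dilation + 1 input columns while the number of weights stays unchanged; the existing bounds check still substitutes zero for taps that land outside the input.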
@@ -302,8 +307,9 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, for (int filter_y = 0; filter_y < filter_height; ++filter_y) { for (int filter_x = 0; filter_x < filter_width; ++filter_x) { for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - const int in_x = in_x_origin + filter_x; - const int in_y = in_y_origin + filter_y; + const int in_x = in_x_origin + dilation_width_factor * filter_x; + const int in_y = + in_y_origin + dilation_height_factor * filter_y; // If the location is outside the bounds of the input image, // use zero as a default value. if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && @@ -335,6 +341,24 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, } } +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1, + pad_width, pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + // legacy, for compatibility with old checked-in code template inline void Conv(const uint8* input_data, const Dims<4>& input_dims, @@ -1374,13 +1398,143 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, output_dims); } -inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, - int32 input1_offset, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { +// Element-wise mul that can often be used for inner loop of broadcast Mul as +// well as the non-broadcast Mul. 
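+// Each output element is computed as
+//   (input1[i] + input1_offset) * (input2[i] + input2_offset)
+// in 32 bits, rescaled with output_multiplier / output_shift, offset by
+// output_offset, and clamped to the quantized activation range.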
+inline void MulElementwise(int size, const ArithmeticParams& params, + const uint8* input1_data, const uint8* input2_data, + uint8* output_data) { + for (int i = 0; i < size; ++i) { + const int32 input1_val = params.input1_offset + input1_data[i]; + const int32 input2_val = params.input2_offset + input2_data[i]; + const int32 unclamped_result = + params.output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val, + params.output_multiplier, + params.output_shift); + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, unclamped_result)); + output_data[i] = static_cast(clamped_output); + } +} + +inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const uint8* input1_data, + const RuntimeShape& input2_shape, const uint8* input2_data, + const RuntimeShape& output_shape, uint8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + gemmlowp::ScopedProfilingLabel label("Mul/8bit"); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + + MulElementwise(flat_size, params, input1_data, input2_data, output_data); +} + +inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, + const RuntimeShape& unswitched_input1_shape, + const uint8* unswitched_input1_data, + const RuntimeShape& unswitched_input2_shape, + const uint8* unswitched_input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + ArithmeticParams switched_params = unswitched_params; + switched_params.input1_offset = unswitched_params.input2_offset; + switched_params.input2_offset = unswitched_params.input1_offset; + + const bool use_unswitched = + unswitched_params.broadcast_category == + tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; + + const ArithmeticParams& params = + use_unswitched ? unswitched_params : switched_params; + const uint8* input1_data = + use_unswitched ? unswitched_input1_data : unswitched_input2_data; + const uint8* input2_data = + use_unswitched ? unswitched_input2_data : unswitched_input1_data; + + // Fivefold nested loops. The second input resets its position for each + // iteration of the second loop. The first input resets its position at the + // beginning of the fourth loop. The innermost loop is an elementwise Mul of + // sections of the arrays. 
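+  // Concretely: the output pointer and the second input pointer advance by
+  // one y4-length run per step of the innermost (i3) loop; the first input
+  // pointer advances by one run only after each complete i3 loop; the second
+  // input is rewound to input2_data_reset at every i1 iteration, and the
+  // reset point itself moves forward only once an entire i0 iteration is done.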
+ uint8* output_data_ptr = output_data; + const uint8* input1_data_ptr = input1_data; + const uint8* input2_data_reset = input2_data; + int y0 = params.broadcast_shape[0]; + int y1 = params.broadcast_shape[1]; + int y2 = params.broadcast_shape[2]; + int y3 = params.broadcast_shape[3]; + int y4 = params.broadcast_shape[4]; + for (int i0 = 0; i0 < y0; ++i0) { + const uint8* input2_data_ptr; + for (int i1 = 0; i1 < y1; ++i1) { + input2_data_ptr = input2_data_reset; + for (int i2 = 0; i2 < y2; ++i2) { + for (int i3 = 0; i3 < y3; ++i3) { + MulElementwise(y4, params, input1_data_ptr, input2_data_ptr, + output_data_ptr); + input2_data_ptr += y4; + output_data_ptr += y4; + } + input1_data_ptr += y4; + } + } + input2_data_reset = input2_data_ptr; + } +} + +inline void BroadcastMul4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const uint8* input1_data, + const RuntimeShape& input2_shape, + const uint8* input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); + + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + const int32 input1_val = + params.input1_offset + + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; + const int32 input2_val = + params.input2_offset + + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; + const int32 unclamped_result = + params.output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, params.output_multiplier, + params.output_shift); + const int32 clamped_output = std::min( + params.quantized_activation_max, + std::max(params.quantized_activation_min, unclamped_result)); + output_data[Offset(extended_output_shape, b, y, x, c)] = + static_cast(clamped_output); + } + } + } + } +} + +// Transitional version that will be moved shortly to legacy_reference_ops, as +// part of RuntimeShape revisions. 
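+// It keeps the Dims<4>-based signature but now takes a new-style (left) shift
+// directly; the kReverseShift conversion is applied by the legacy
+// BroadcastMul wrapper in legacy_reference_ops.h.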
+inline void BroadcastMul4DSlow(const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); NdArrayDesc<4> desc1; @@ -1407,9 +1561,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, const int32 input2_val = input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; const int32 unclamped_result = - output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( - input1_val * input2_val, output_multiplier, - kReverseShift * output_shift); + output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, output_multiplier, output_shift); const int32 clamped_output = std::min(output_activation_max, std::max(output_activation_min, unclamped_result)); @@ -1464,21 +1618,6 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, } } -// legacy, for compatibility with old checked-in code -template -inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, - int32 input1_offset, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, - input2_dims, input2_offset, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_data, output_dims); -} - // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then @@ -3370,28 +3509,50 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, } } -template -inline void PadV2(const T* input_data, const Dims<4>& input_dims, - const std::vector& left_paddings, - const std::vector& right_paddings, T* output_data, - const Dims<4>& output_dims, const T pad_value) { - TFLITE_DCHECK_EQ(left_paddings.size(), 4); - TFLITE_DCHECK_EQ(right_paddings.size(), 4); +// There are two versions of pad: Pad and PadV2. In PadV2 there is a second +// scalar input that provides the padding value. Therefore pad_value_ptr can be +// equivalent to a simple input1_data. For Pad, it should point to a zero +// value. +// +// Note that two typenames are required, so that T=P=int32 is considered a +// specialization distinct from P=int32. +template +inline void PadImpl(const tflite::PadParams& op_params, + const RuntimeShape& input_shape, const T* input_data, + const P* pad_value_ptr, const RuntimeShape& output_shape, + T* output_data) { + RuntimeShape ext_input_shape = RuntimeShape::ExtendedShape(4, input_shape); + RuntimeShape ext_output_shape = RuntimeShape::ExtendedShape(4, output_shape); + TFLITE_DCHECK_LE(op_params.left_padding_count, 4); + TFLITE_DCHECK_LE(op_params.right_padding_count, 4); - const int output_batch = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int output_depth = ArraySize(output_dims, 0); + // Runtime calls are currently fixed at 4 dimensions. 
Copy inputs so + // we can pad them to 4 dims (yes, we are "padding the padding"). + std::vector left_padding_copy(4, 0); + for (int i = 0; i < op_params.left_padding_count; ++i) { + left_padding_copy[i] = op_params.left_padding[i]; + } + std::vector right_padding_copy(4, 0); + for (int i = 0; i < op_params.right_padding_count; ++i) { + right_padding_copy[i] = op_params.right_padding[i]; + } - const int left_b_padding = left_paddings[3]; - const int left_h_padding = left_paddings[2]; - const int left_w_padding = left_paddings[1]; - const int left_d_padding = left_paddings[0]; + const int output_batch = ext_output_shape.Dims(0); + const int output_height = ext_output_shape.Dims(1); + const int output_width = ext_output_shape.Dims(2); + const int output_depth = ext_output_shape.Dims(3); - const int right_b_padding = right_paddings[3]; - const int right_h_padding = right_paddings[2]; - const int right_w_padding = right_paddings[1]; - const int right_d_padding = right_paddings[0]; + const int left_b_padding = left_padding_copy[0]; + const int left_h_padding = left_padding_copy[1]; + const int left_w_padding = left_padding_copy[2]; + const int left_d_padding = left_padding_copy[3]; + + const int right_b_padding = right_padding_copy[0]; + const int right_h_padding = right_padding_copy[1]; + const int right_w_padding = right_padding_copy[2]; + const int right_d_padding = right_padding_copy[3]; + + const T pad_value = *pad_value_ptr; const T* in_ptr = input_data; T* out_ptr = output_data; @@ -3417,7 +3578,59 @@ inline void PadV2(const T* input_data, const Dims<4>& input_dims, } } -// Legacy Pad() method that casts an int32_t to T before padding. +template +inline void Pad(const tflite::PadParams& op_params, + const RuntimeShape& input_shape, const T* input_data, + const P* pad_value_ptr, const RuntimeShape& output_shape, + T* output_data) { + PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape, + output_data); +} + +// The second (pad-value) input can be int32 when, say, the first is uint8. +template +inline void Pad(const tflite::PadParams& op_params, + const RuntimeShape& input_shape, const T* input_data, + const int32* pad_value_ptr, const RuntimeShape& output_shape, + T* output_data) { + const T converted_pad_value = static_cast(*pad_value_ptr); + PadImpl(op_params, input_shape, input_data, &converted_pad_value, + output_shape, output_data); +} + +// This version avoids conflicting template matching. +template <> +inline void Pad(const tflite::PadParams& op_params, + const RuntimeShape& input_shape, const int32* input_data, + const int32* pad_value_ptr, const RuntimeShape& output_shape, + int32* output_data) { + PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape, + output_data); +} + +// Legacy signature, function covered both Pad and PadV2. 
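+// It rebuilds the new-style PadParams from Dims<4>-ordered (innermost-first)
+// padding vectors by reversing the index order, mirroring DimsToShape().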
+template +inline void PadV2(const T* input_data, const Dims<4>& input_dims, + const std::vector& left_paddings, + const std::vector& right_paddings, T* output_data, + const Dims<4>& output_dims, const T pad_value) { + TFLITE_DCHECK_EQ(left_paddings.size(), 4); + TFLITE_DCHECK_EQ(right_paddings.size(), 4); + tflite::PadParams op_params; + op_params.left_padding_count = 4; + op_params.right_padding_count = 4; + for (int i = 0; i < 4; ++i) { + op_params.left_padding[i] = left_paddings[3 - i]; + op_params.right_padding[i] = right_paddings[3 - i]; + } + // SetFloatOrInt(pad_value, &op_params.pad_value); + const T pad_value_copy = pad_value; + + Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy, + DimsToShape(output_dims), output_data); +} + +// Old Pad that calls legacy PadV2. template inline void Pad(const T* input_data, const Dims<4>& input_dims, const std::vector& left_paddings, @@ -3428,13 +3641,15 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, output_dims, converted_pad_value); } +// Old Pad that only padded with 0. template inline void Pad(const T* input_data, const Dims<4>& input_dims, const std::vector& left_paddings, const std::vector& right_paddings, T* output_data, const Dims<4>& output_dims) { - Pad(input_data, input_dims, left_paddings, right_paddings, output_data, - output_dims, 0); + const T pad_value = static_cast(0); + PadV2(input_data, input_dims, left_paddings, right_paddings, output_data, + output_dims, pad_value); } template @@ -3491,37 +3706,61 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, } template -inline void Slice(const T* input_data, const Dims<4>& input_dims, - const std::vector& begin, const std::vector& size, - T* output_data, const Dims<4>& output_dims) { - // TODO(dkalenichenko): This op only supports 4D tensors. - TFLITE_DCHECK_EQ(begin.size(), 4); - TFLITE_DCHECK_EQ(size.size(), 4); - const int start_b = begin[3]; - const int stop_b = - size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3]; - const int start_h = begin[2]; - const int stop_h = - size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2]; - const int start_w = begin[1]; - const int stop_w = - size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1]; - const int start_d = begin[0]; - const int stop_d = - size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0]; +inline void Slice(const tflite::SliceParams& op_params, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& output_shape, T* output_data) { + RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape); + // TODO(dkalenichenko): This op only supports 4D tensors or smaller. + TFLITE_DCHECK_LE(op_params.begin_count, 4); + TFLITE_DCHECK_LE(op_params.size_count, 4); + const int begin_count = op_params.begin_count; + const int size_count = op_params.size_count; + // We front-pad the begin and size vectors. + const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0]; + const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1) + ? ext_shape.Dims(0) - start_b + : start_b + op_params.size[0]; + const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3]; + const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1) + ? ext_shape.Dims(1) - start_h + : start_h + op_params.size[size_count - 3]; + const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2]; + const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1) + ? 
ext_shape.Dims(2) - start_w + : start_w + op_params.size[size_count - 2]; + const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1]; + const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1) + ? ext_shape.Dims(3) - start_d + : start_d + op_params.size[size_count - 1]; T* out_ptr = output_data; for (int in_b = start_b; in_b < stop_b; ++in_b) { for (int in_h = start_h; in_h < stop_h; ++in_h) { for (int in_w = start_w; in_w < stop_w; ++in_w) { for (int in_d = start_d; in_d < stop_d; ++in_d) { - *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + *out_ptr++ = input_data[Offset(ext_shape, in_b, in_h, in_w, in_d)]; } } } } } +template +inline void Slice(const T* input_data, const Dims<4>& input_dims, + const std::vector& begin, const std::vector& size, + T* output_data, const Dims<4>& output_dims) { + tflite::SliceParams op_params; + op_params.begin_count = 4; + op_params.size_count = 4; + for (int i = 0; i < 4; ++i) { + op_params.begin[i] = begin[3 - i]; + op_params.size[i] = size[3 - i]; + } + + Slice(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + template inline void Exp(const T* input_data, const size_t num_elements, T* output_data) { @@ -3790,10 +4029,10 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims, } template -void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, T* output_data, - const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input1_dims); +void Minimum(const RuntimeShape& input1_shape, const T* input1_data, + const T* input2_data, const RuntimeShape& output_shape, + T* output_data) { + const int flat_size = MatchingFlatSize(input1_shape, output_shape); auto min_value = input2_data[0]; for (int i = 0; i < flat_size; i++) { @@ -3802,10 +4041,10 @@ void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, } template -void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, T* output_data, - const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input1_dims); +void Maximum(const RuntimeShape& input1_shape, const T* input1_data, + const T* input2_data, const RuntimeShape& output_shape, + T* output_data) { + const int flat_size = MatchingFlatSize(input1_shape, output_shape); auto max_value = input2_data[0]; for (int i = 0; i < flat_size; i++) { @@ -3813,6 +4052,22 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, } } +template +void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + Minimum(DimsToShape(input1_dims), input1_data, input2_data, + DimsToShape(output_dims), output_data); +} + +template +void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + Maximum(DimsToShape(input1_dims), input1_data, input2_data, + DimsToShape(output_dims), output_data); +} + template void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index c44698b677a..7b6838db53a 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -129,6 +129,13 @@ class RuntimeShape { } } + 
RuntimeShape(int shape_size, int32 value) : size_(0) { + Resize(shape_size); + for (int i = 0; i < shape_size; ++i) { + SetDim(i, value); + } + } + RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) { ReplaceWith(dimensions_count, dims_data); } @@ -237,7 +244,7 @@ class RuntimeShape { bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); } private: - // For use only by ExtendFrom(), written to guarantee (return-value) copy + // For use only by ExtendedShape(), written to guarantee (return-value) copy // elision in C++17. // This creates a shape padded to the desired size with the specified value. RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value) @@ -645,22 +652,6 @@ void ComputeStrides(Dims* dims) { } } -struct PoolParams { - FusedActivationFunctionType activation; - PaddingType padding_type; - PaddingValues padding_values; - int stride_height; - int stride_width; - int filter_height; - int filter_width; - // uint8, etc, activation params. - int32 quantized_activation_min; - int32 quantized_activation_max; - // float activation params. - float float_activation_min; - float float_activation_max; -}; - enum class BroadcastableOpCategory : uint8 { kNone, kNonBroadcast, // Matching input shapes. @@ -721,6 +712,37 @@ inline void SetActivationParams(int32 min, int32 max, params->quantized_activation_max = max; } +struct PadParams { + int8 left_padding_count; + int32 left_padding[4]; + int8 right_padding_count; + int32 right_padding[4]; + // FloatOrInt pad_value; +}; + +struct PoolParams { + FusedActivationFunctionType activation; + PaddingType padding_type; + PaddingValues padding_values; + int stride_height; + int stride_width; + int filter_height; + int filter_width; + // uint8, etc, activation params. + int32 quantized_activation_min; + int32 quantized_activation_max; + // float activation params. + float float_activation_min; + float float_activation_max; +}; + +struct SliceParams { + int8 begin_count; + int32 begin[4]; + int8 size_count; + int32 size[4]; +}; + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index 349f3e67261..561e39cfc69 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -93,7 +93,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { input1->params.scale * input2->params.scale / output->params.scale; QuantizeMultiplierSmallerThanOneExp( real_multiplier, &data->output_multiplier, &data->output_shift); - data->output_shift *= -1; } return context->ResizeTensor(context, output, output_size); @@ -161,9 +160,9 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, // The quantized version of Mul doesn't support activations, so we // always use BroadcastMul. 
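For context on why Prepare() above no longer negates the shift: under the new convention (see the kReverseShift comment earlier in this diff) the value produced by QuantizeMultiplierSmallerThanOneExp is consumed as a signed left shift, and the dispatch just below routes to BroadcastMul4DSlow, which uses it directly. The editor-added sketch that follows is not part of the patch, uses an arbitrary multiplier, and only mirrors the idea of that representation with std::frexp:

// Editor-added sketch, not part of the patch; the multiplier is arbitrary.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Plays the role of input1_scale * input2_scale / output_scale in Prepare().
  const double real_multiplier = 0.05;
  int exponent = 0;
  // mantissa in [0.5, 1); exponent is negative because the multiplier is < 1.
  const double mantissa = std::frexp(real_multiplier, &exponent);
  // Store the mantissa as a Q31 fixed-point integer; the exponent is kept as a
  // signed left shift instead of being negated into a right shift.
  const int32_t quantized_multiplier =
      static_cast<int32_t>(std::lround(mantissa * (1ll << 31)));
  std::printf("%g ~= %d * 2^-31 * 2^%d\n", real_multiplier,
              quantized_multiplier, exponent);
  return 0;
}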
if (kernel_type == kReference) { - TF_LITE_MUL(reference_ops, BroadcastMul); + TF_LITE_MUL(reference_ops, BroadcastMul4DSlow); } else { - TF_LITE_MUL(optimized_ops, BroadcastMul); + TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow); } #undef TF_LITE_MUL } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 && diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 8d2c108116e..6159311910b 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -127,9 +127,11 @@ const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op, const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op, int version) const { - // Return the NULL Op for all ops whose name start with "Eager:", allowing + // Return the NULL Op for all ops whose name start with "Eager", allowing // the interpreter to delegate their execution. - if (string(op).find("Eager:") == 0) { + // TODO(ycling): Refactoring and extract an `IsEagerOp` function into + // `lite:framework` build target. + if (string(op).find("Eager") == 0) { static TfLiteRegistration null_op{ nullptr, nullptr, &UnsupportedTensorFlowOp, nullptr, nullptr, BuiltinOperator_CUSTOM, diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 13325a8c7c6..45c92a86716 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -24,20 +24,27 @@ limitations under the License. #include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" #ifdef __ANDROID__ +#include #include #endif namespace tflite { void logError(const char* format, ...) { - // TODO(mikie): use android logging, stderr is not captured for Java - // applications - va_list args; - va_start(args, format); - vfprintf(stderr, format, args); - va_end(args); + // stderr is convenient for native tests, but is not captured for apps + va_list args_for_stderr; + va_start(args_for_stderr, format); + vfprintf(stderr, format, args_for_stderr); + va_end(args_for_stderr); fprintf(stderr, "\n"); fflush(stderr); +#ifdef __ANDROID__ + // produce logcat output for general consumption + va_list args_for_log; + va_start(args_for_log, format); + __android_log_vprint(ANDROID_LOG_ERROR, "tflite", format, args_for_log); + va_end(args_for_log); +#endif } #define FATAL(...) 
\ @@ -564,8 +571,14 @@ TfLiteStatus AddOpsAndParams( nn_op_type = ANEURALNETWORKS_L2_NORMALIZATION; if (reinterpret_cast(node.builtin_data) ->activation != kTfLiteActNone) { - FATAL( + logError( "NNAPI does not support L2Normalization with fused activations"); + return kTfLiteError; + } + if ((node.inputs->size > 0) && + (interpreter->tensor(node.inputs->data[0])->dims->size != 4)) { + logError("NNAPI only supports input rank 4 for L2Normalization"); + return kTfLiteError; } break; case tflite::BuiltinOperator_HASHTABLE_LOOKUP: diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 52ef43d71f2..5ec52035add 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -53,6 +53,7 @@ from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session from tensorflow.python.framework import graph_util as _tf_graph_util +from tensorflow.python.framework import ops as _ops from tensorflow.python.framework.importer import import_graph_def as _import_graph_def from tensorflow.python.saved_model import signature_constants as _signature_constants from tensorflow.python.saved_model import tag_constants as _tag_constants @@ -193,40 +194,41 @@ class TocoConverter(object): The graph is not frozen. input_arrays or output_arrays contains an invalid tensor name. """ - with _session.Session() as sess: - # Read GraphDef from file. - graph_def = _graph_pb2.GraphDef() - with open(graph_def_file, "rb") as f: - file_content = f.read() - try: - graph_def.ParseFromString(file_content) - except (_text_format.ParseError, DecodeError): + with _ops.Graph().as_default(): + with _session.Session() as sess: + # Read GraphDef from file. + graph_def = _graph_pb2.GraphDef() + with open(graph_def_file, "rb") as f: + file_content = f.read() try: - print("Ignore 'tcmalloc: large alloc' warnings.") - - if not isinstance(file_content, str): - if PY3: - file_content = file_content.decode('utf-8') - else: - file_content = file_content.encode('utf-8') - _text_format.Merge(file_content, graph_def) + graph_def.ParseFromString(file_content) except (_text_format.ParseError, DecodeError): - raise ValueError( - "Unable to parse input file '{}'.".format(graph_def_file)) - sess.graph.as_default() - _import_graph_def(graph_def, name="") + try: + print("Ignore 'tcmalloc: large alloc' warnings.") - # Get input and output tensors. - input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays) - output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays) - _set_tensor_shapes(input_tensors, input_shapes) + if not isinstance(file_content, str): + if PY3: + file_content = file_content.decode("utf-8") + else: + file_content = file_content.encode("utf-8") + _text_format.Merge(file_content, graph_def) + except (_text_format.ParseError, DecodeError): + raise ValueError( + "Unable to parse input file '{}'.".format(graph_def_file)) + _import_graph_def(graph_def, name="") - # Check if graph is frozen. - if not _is_frozen_graph(sess): - raise ValueError("Please freeze the graph using freeze_graph.py.") + # Get input and output tensors. + input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays) + output_tensors = _get_tensors_from_tensor_names(sess.graph, + output_arrays) + _set_tensor_shapes(input_tensors, input_shapes) - # Create TocoConverter class. 
- return cls(sess.graph_def, input_tensors, output_tensors) + # Check if graph is frozen. + if not _is_frozen_graph(sess): + raise ValueError("Please freeze the graph using freeze_graph.py.") + + # Create TocoConverter class. + return cls(sess.graph_def, input_tensors, output_tensors) @classmethod def from_saved_model(cls, diff --git a/tensorflow/contrib/lite/rpi_makefile.inc b/tensorflow/contrib/lite/rpi_makefile.inc deleted file mode 100644 index 832ef5824be..00000000000 --- a/tensorflow/contrib/lite/rpi_makefile.inc +++ /dev/null @@ -1,33 +0,0 @@ -# Settings for Raspberry Pi. -ifeq ($(TARGET), RPI) - ifeq ($(TARGET_ARCH), armv7) - CXXFLAGS += \ - -march=armv7-a \ - -mfpu=neon-vfpv4 \ - -funsafe-math-optimizations \ - -ftree-vectorize - - CCFLAGS += \ - -march=armv7-a \ - -mfpu=neon-vfpv4 \ - -funsafe-math-optimizations \ - -ftree-vectorize - - LDFLAGS := \ - -Wl,--no-export-dynamic \ - -Wl,--exclude-libs,ALL \ - -Wl,--gc-sections \ - -Wl,--as-needed - endif - - LIBS := \ - -lstdc++ \ - -lpthread \ - -lm \ - -ldl - - OBJDIR := $(OBJDIR)rpi_$(TARGET_ARCH)/ - LIBDIR := $(LIBDIR)rpi_$(TARGET_ARCH)/ - BINDIR := $(BINDIR)rpi_$(TARGET_ARCH)/ - DEPDIR := $(DEPDIR)rpi_$(TARGET_ARCH)/ -endif diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index a788d41ba7b..89912fd116a 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -162,11 +162,12 @@ cc_library( ":test_runner", "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/delegates/eager:delegate", "//tensorflow/contrib/lite/kernels:builtin_ops", ], ) -cc_test( +tf_cc_test( name = "tflite_driver_test", size = "small", srcs = ["tflite_driver_test.cc"], diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 52ef0d5b865..9dd5c8ae449 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -1255,6 +1255,75 @@ def make_conv_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +# Note: This is a regression test for a bug (b/112436267) that Toco incorrectly +# fuses weights when multiple Conv2D/FULLY_CONNECTED ops share the same constant +# weight tensor. +def make_conv_with_shared_weights_tests(zip_path): + """Make a test where 2 Conv ops shared the same constant weight tensor.""" + + test_parameters = [{ + "input_shape": [[1, 10, 10, 3]], + "filter_shape": [[3, 3]], + "strides": [[1, 1, 1, 1]], + "dilations": [[1, 1, 1, 1]], + "padding": ["SAME"], + "data_format": ["NHWC"], + "channel_multiplier": [1], + }] + + def get_tensor_shapes(parameters): + input_shape = parameters["input_shape"] + filter_size = parameters["filter_shape"] + filter_shape = filter_size + [ + input_shape[3], parameters["channel_multiplier"] + ] + return [input_shape, filter_shape] + + def build_graph(parameters): + """Build a conv graph given `parameters`.""" + input_shape, filter_shape = get_tensor_shapes(parameters) + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=input_shape) + + # Construct a constant weights tensor which will be used by both Conv2D. + filter_tensor = tf.constant( + create_tensor_data(np.float32, filter_shape), dtype=tf.float32) + input_tensors = [input_tensor] + + # Construct 2 Conv2D operations which use exactly the same input and + # weights. 
+ result1 = tf.nn.conv2d( + input_tensor, + filter_tensor, + strides=parameters["strides"], + dilations=parameters["dilations"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + result2 = tf.nn.conv2d( + input_tensor, + filter_tensor, + strides=parameters["strides"], + dilations=parameters["dilations"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + # Add MUL ops after Conv2D ops. These MUL ops should be fused into the + # weights of Conv2D. + result1 = result1 * 2 + result2 = result2 * 3 + # Add the 2 results up. + out = result1 + result2 + return input_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + # Build list of input values either containing 1 tensor (input) or 2 tensors + # (input, filter) based on whether filter is constant or variable input. + input_shape, unused_filter_shape = get_tensor_shapes(parameters) + values = [create_tensor_data(np.float32, input_shape)] + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_depthwiseconv_tests(zip_path): """Make a set of tests to do convolution.""" diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/contrib/lite/testing/generate_testspec.cc index f29c188e6c2..62cbeccd331 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec.cc +++ b/tensorflow/contrib/lite/testing/generate_testspec.cc @@ -114,7 +114,13 @@ bool GenerateTestSpecFromTensorflowModel( // different set. std::vector input_values = GenerateInputValues(input_layer, input_layer_type, input_layer_shape); - if (input_values.empty()) return false; + if (input_values.empty()) { + std::cerr << "Unable to generate input values for the TensorFlow model. " + "Make sure the correct values are defined for " + "input_layer, input_layer_type, and input_layer_shape." + << std::endl; + return false; + } // Run TensorFlow. for (int j = 0; j < input_values.size(); j++) { diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index e475f256c01..e67fee2a1ca 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -33,13 +33,18 @@ namespace testing { namespace { bool FLAGS_ignore_known_bugs = true; -// TODO(b/71769302) zip_files_dir should have a more accurate default, if -// possible -string* FLAGS_zip_file_path = new string("./"); +// As archive file names are test-specific, no default is possible. +// +// This test supports input as both zip and tar, as a stock android image does +// not have unzip but does have tar. +string* FLAGS_zip_file_path = new string; +string* FLAGS_tar_file_path = new string; #ifndef __ANDROID__ string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip"); +string* FLAGS_tar_binary_path = new string("/bin/tar"); #else string* FLAGS_unzip_binary_path = new string("/system/bin/unzip"); +string* FLAGS_tar_binary_path = new string("/system/bin/tar"); #endif bool FLAGS_use_nnapi = false; bool FLAGS_ignore_unsupported_nnapi = false; @@ -98,11 +103,11 @@ std::map kBrokenTests = { "77546240"}, }; -// Allows test data to be unzipped into a temporary directory and makes +// Allows test data to be unarchived into a temporary directory and makes // sure those temporary directories are removed later. 
-class ZipEnvironment : public ::testing::Environment { +class ArchiveEnvironment : public ::testing::Environment { public: - ~ZipEnvironment() override {} + ~ArchiveEnvironment() override {} // Delete all temporary directories on teardown. void TearDown() override { @@ -114,15 +119,26 @@ class ZipEnvironment : public ::testing::Environment { temporary_directories_.clear(); } - // Unzip `zip` file into a new temporary directory `out_dir`. - tensorflow::Status UnZip(const string& zip, string* out_dir) { + // Unarchive `archive` file into a new temporary directory `out_dir`. + tensorflow::Status UnArchive(const string& zip, const string& tar, + string* out_dir) { string dir; TF_CHECK_OK(MakeTemporaryDirectory(&dir)); tensorflow::SubProcess proc; - string unzip_binary = *FLAGS_unzip_binary_path; - TF_CHECK_OK(env->FileExists(unzip_binary)); - TF_CHECK_OK(env->FileExists(zip)); - proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip}); + if (!zip.empty()) { + string unzip_binary = *FLAGS_unzip_binary_path; + TF_CHECK_OK(env->FileExists(unzip_binary)); + TF_CHECK_OK(env->FileExists(zip)); + proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip}); + } else { + string tar_binary = *FLAGS_tar_binary_path; + TF_CHECK_OK(env->FileExists(tar_binary)); + TF_CHECK_OK(env->FileExists(tar)); + // 'o' needs to be explicitly set on Android so that + // untarring works as non-root (otherwise tries to chown + // files, which fails) + proc.SetProgram(tar_binary, {"tar", "xfo", tar, "-C", dir}); + } proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); if (!proc.Start()) @@ -156,15 +172,15 @@ class ZipEnvironment : public ::testing::Environment { std::vector temporary_directories_; }; -// Return the singleton zip_environment. -ZipEnvironment* zip_environment() { - static ZipEnvironment* env = new ZipEnvironment; +// Return the singleton archive_environment. +ArchiveEnvironment* archive_environment() { + static ArchiveEnvironment* env = new ArchiveEnvironment; return env; } -// Read the manifest.txt out of the unarchived zip file. Specifically +// Read the manifest.txt out of the unarchived archive file. Specifically // `original_file` is the original zip file for error messages. `dir` is -// the temporary directory where the zip file has been unarchived and +// the temporary directory where the archive file has been unarchived and // `test_paths` is the list of test prefixes that were in the manifest. // Note, it is an error for a manifest to contain no tests. tensorflow::Status ReadManifest(const string& original_file, const string& dir, @@ -190,12 +206,22 @@ tensorflow::Status ReadManifest(const string& original_file, const string& dir, return tensorflow::Status::OK(); } -// Get a list of tests from a zip file `zip_file_name`. 
-std::vector UnarchiveZipAndFindTestNames(const string& zip_file) { +// Get a list of tests from either zip or tar file +std::vector UnarchiveAndFindTestNames(const string& zip_file, + const string& tar_file) { + if (zip_file.empty() && tar_file.empty()) { + TF_CHECK_OK(tensorflow::Status(tensorflow::error::UNKNOWN, + "Neither zip_file nor tar_file was given")); + } string decompress_tmp_dir; - TF_CHECK_OK(zip_environment()->UnZip(zip_file, &decompress_tmp_dir)); + TF_CHECK_OK(archive_environment()->UnArchive(zip_file, tar_file, + &decompress_tmp_dir)); std::vector stuff; - TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff)); + if (!zip_file.empty()) { + TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff)); + } else { + TF_CHECK_OK(ReadManifest(tar_file, decompress_tmp_dir, &stuff)); + } return stuff; } @@ -223,8 +249,7 @@ TEST_P(OpsTest, RunZipTests) { string message = test_driver.GetErrorMessage(); if (bug_number.empty()) { if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) { - EXPECT_EQ(message, string("Failed to invoke NNAPI interpreter")) - << message; + EXPECT_EQ(message, string("Failed to invoke interpreter")) << message; } else { EXPECT_TRUE(result) << message; } @@ -256,27 +281,34 @@ struct ZipPathParamName { } }; -INSTANTIATE_TEST_CASE_P( - tests, OpsTest, - ::testing::ValuesIn(UnarchiveZipAndFindTestNames(*FLAGS_zip_file_path)), - ZipPathParamName()); +INSTANTIATE_TEST_CASE_P(tests, OpsTest, + ::testing::ValuesIn(UnarchiveAndFindTestNames( + *FLAGS_zip_file_path, *FLAGS_tar_file_path)), + ZipPathParamName()); } // namespace testing } // namespace tflite int main(int argc, char** argv) { - ::testing::AddGlobalTestEnvironment(tflite::testing::zip_environment()); + ::testing::AddGlobalTestEnvironment(tflite::testing::archive_environment()); std::vector flags = { tensorflow::Flag( "ignore_known_bugs", &tflite::testing::FLAGS_ignore_known_bugs, "If a particular model is affected by a known bug, the " "corresponding test should expect the outputs to not match."), - tensorflow::Flag("zip_file_path", tflite::testing::FLAGS_zip_file_path, - "Required: Location of the test zip file."), + tensorflow::Flag( + "tar_file_path", tflite::testing::FLAGS_tar_file_path, + "Required (or zip_file_path): Location of the test tar file."), + tensorflow::Flag( + "zip_file_path", tflite::testing::FLAGS_zip_file_path, + "Required (or tar_file_path): Location of the test zip file."), tensorflow::Flag("unzip_binary_path", tflite::testing::FLAGS_unzip_binary_path, - "Required: Location of a suitable unzip binary."), + "Location of a suitable unzip binary."), + tensorflow::Flag("tar_binary_path", + tflite::testing::FLAGS_tar_binary_path, + "Location of a suitable tar binary."), tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi, "Whether to enable the NNAPI delegate"), tensorflow::Flag("ignore_unsupported_nnapi", diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc index ec435ca60d9..30381ba0283 100644 --- a/tensorflow/contrib/lite/testing/tf_driver.cc +++ b/tensorflow/contrib/lite/testing/tf_driver.cc @@ -179,7 +179,9 @@ void TfDriver::Invoke() { auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()}, output_names_, {}, &output_tensors_); if (!status.ok()) { - Invalidate("Failed to run input data on graph"); + Invalidate( + "Failed to run input data on graph. 
Make sure the correct value is " + "defined for the input and output arrays."); } } diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h index 695c2a3de6c..3874bc31d7d 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h +++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h @@ -33,6 +33,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { string input_layer_shape; string output_layer; int32_t num_runs_per_pass = 100; + string delegate; } values; std::vector flags = { @@ -42,18 +43,21 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { "Path of tensorflow lite model."), tensorflow::Flag("input_layer", &values.input_layer, "Names of input tensors, separated by comma. Example: " - "input_1,input_2"), + "input_1,input_2."), tensorflow::Flag("input_layer_type", &values.input_layer_type, "Data types of input tensors, separated by comma. " - "Example: float,int"), + "Example: float,int."), tensorflow::Flag( "input_layer_shape", &values.input_layer_shape, - "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2"), + "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2."), tensorflow::Flag("output_layer", &values.output_layer, - "Names of output tensors, separated by comma. Example " - "output_1,output_2"), + "Names of output tensors, separated by comma. Example: " + "output_1,output_2."), tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass, - "Number of full runs in each pass."), + "[optional] Number of full runs in each pass."), + tensorflow::Flag("delegate", &values.delegate, + "[optional] Delegate to use for executing ops. Must be " + "`{\"\", EAGER}`"), }; bool no_inputs = *argc == 1; @@ -61,6 +65,14 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { if (!success || no_inputs || (*argc == 2 && !strcmp(argv[1], "--helpfull"))) { fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); return {}; + } else if (values.tensorflow_model.empty() || values.tflite_model.empty() || + values.input_layer.empty() || values.input_layer_type.empty() || + values.input_layer_shape.empty() || values.output_layer.empty()) { + fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); + return {}; + } else if (!(values.delegate == "" || values.delegate == "EAGER")) { + fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); + return {}; } return {values.tensorflow_model, @@ -69,7 +81,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { Split(values.input_layer_type, ","), Split(values.input_layer_shape, ":"), Split(values.output_layer, ","), - values.num_runs_per_pass}; + values.num_runs_per_pass, + values.delegate}; } } // namespace testing diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.cc b/tensorflow/contrib/lite/testing/tflite_diff_util.cc index 19f34c0a51e..c6ca796ac25 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_util.cc +++ b/tensorflow/contrib/lite/testing/tflite_diff_util.cc @@ -33,7 +33,7 @@ bool RunDiffTest(const DiffOptions& options, int num_invocations) { options.input_layer_shape, options.output_layer)) { return false; } - TfLiteDriver tflite_driver(/*use_nnapi=*/true); + TfLiteDriver tflite_driver(/*use_nnapi=*/true, options.delegate); tflite_driver.LoadModel(options.tflite_model); return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver); } diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/contrib/lite/testing/tflite_diff_util.h 
index 4ab2f230fdc..f67992139f6 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_util.h +++ b/tensorflow/contrib/lite/testing/tflite_diff_util.h @@ -44,6 +44,9 @@ struct DiffOptions { // each of the passes. The first pass has a single inference, while the // second pass does multiple inferences back to back. int num_runs_per_pass; + // Path to the delegate library to be loaded in order to execute ops. Must be + // `{"", EAGER}`. + string delegate; }; // Run a single TensorFLow Lite diff test with a given options. diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 4d08fb54580..71a98a3d568 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/delegates/eager/delegate.h" #include "tensorflow/contrib/lite/testing/split.h" namespace tflite { @@ -135,7 +136,13 @@ class TfLiteDriver::Expectation { size_t num_elements_; }; -TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {} +TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name) + : use_nnapi_(use_nnapi) { + if (delegate_name == "EAGER") { + delegate_.reset(new EagerDelegate()); + } +} + TfLiteDriver::~TfLiteDriver() {} void TfLiteDriver::AllocateTensors() { @@ -165,6 +172,13 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) { } interpreter_->UseNNAPI(use_nnapi_); + if (delegate_) { + if (delegate_->Apply(interpreter_.get()) != kTfLiteOk) { + Invalidate("Unable to the build graph using the delegate"); + return; + } + } + must_allocate_tensors_ = true; } diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h index 5493ba3631b..aed35f877d5 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.h +++ b/tensorflow/contrib/lite/testing/tflite_driver.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/contrib/lite/delegates/eager/delegate.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" @@ -28,7 +29,7 @@ namespace testing { // A test runner that feeds inputs into TF Lite and verifies its outputs. 
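+// When constructed with delegate == "EAGER", the Eager delegate is applied to
+// the interpreter as part of LoadModel().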
class TfLiteDriver : public TestRunner { public: - explicit TfLiteDriver(bool use_nnapi); + explicit TfLiteDriver(bool use_nnapi, const string& delegate = ""); ~TfLiteDriver() override; void LoadModel(const string& bin_file_path) override; @@ -52,6 +53,7 @@ class TfLiteDriver : public TestRunner { class Expectation; + std::unique_ptr delegate_; bool use_nnapi_ = false; std::unique_ptr model_; std::unique_ptr interpreter_; diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index aa4a4d88540..02d0890a7af 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -242,9 +242,11 @@ cc_library( "graph_transformations/resolve_constant_random_uniform.cc", "graph_transformations/resolve_constant_range.cc", "graph_transformations/resolve_constant_reshape.cc", + "graph_transformations/resolve_constant_select.cc", "graph_transformations/resolve_constant_shape_or_rank.cc", "graph_transformations/resolve_constant_slice.cc", "graph_transformations/resolve_constant_strided_slice.cc", + "graph_transformations/resolve_constant_tile.cc", "graph_transformations/resolve_constant_transpose.cc", "graph_transformations/resolve_constant_unary.cc", "graph_transformations/resolve_fake_quant_args_from_vars.cc", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index 76c6be00d40..b324631579f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -274,8 +274,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { return false; } - const auto& weights = model->GetArray(preceding_op->inputs[1]); - const auto& bias = model->GetArray(preceding_op->inputs[2]); + const auto& weights_name = preceding_op->inputs[1]; + const auto& bias_name = preceding_op->inputs[2]; + const auto& weights = model->GetArray(weights_name); + const auto& bias = model->GetArray(bias_name); + const int count_ops_consuming_bias = CountOpsWithInput(*model, bias_name); + const int count_ops_consuming_weights = + CountOpsWithInput(*model, weights_name); + if (binary_op->type == OperatorType::kAdd || binary_op->type == OperatorType::kSub) { if (!bias.buffer) { @@ -285,6 +291,13 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { LogName(*binary_op), LogName(*preceding_op)); return false; } + if (count_ops_consuming_bias > 1) { + AddMessageF( + "Not fusing %s because the bias of the preceding %s is consumed by " + "another op", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } } else { if (!weights.buffer || !bias.buffer) { AddMessageF( @@ -293,6 +306,13 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { LogName(*binary_op), LogName(*preceding_op)); return false; } + if (count_ops_consuming_weights > 1 || count_ops_consuming_bias > 1) { + AddMessageF( + "Not fusing %s because the weights or bias of the preceding %s is " + "consumed by another op", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } } int count_ops_consuming_output = diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 8d9a4c4700e..99f4a7d8f61 100644 --- 
a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -190,6 +190,8 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantGather) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSelect) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTile) DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero) DECLARE_GRAPH_TRANSFORMATION(Dequantize) DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index d26c3b2878b..502de88f7cb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -274,6 +274,19 @@ bool PropagateMinMaxAmongArrays(Model* model, return changed; } +bool HardcodeMinMaxForReshape(Model* model, Operator* op) { + Array& input = model->GetArray(op->inputs[0]); + Array& output = model->GetArray(op->outputs[0]); + + // If input and output both exist or do not exist, do nothing. + if ((!input.minmax && !output.minmax) || (input.minmax && output.minmax)) { + return false; + } + + // Otherwise propagate info amongst the input and output array. + return PropagateMinMaxAmongArrays(model, {op->inputs[0], op->outputs[0]}); +} + bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) { CHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS); CHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS); @@ -370,7 +383,6 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { case OperatorType::kSlice: case OperatorType::kStridedSlice: case OperatorType::kSqueeze: - case OperatorType::kReshape: case OperatorType::kExpandDims: case OperatorType::kPad: case OperatorType::kGather: @@ -416,6 +428,10 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForLstmCell(model, op); break; + case OperatorType::kReshape: + changed = HardcodeMinMaxForReshape(model, op); + break; + default: break; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc index 9f5d8b94507..fc49fbda59c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -48,20 +48,26 @@ void RerouteEdges(const string& from_array, const string& to_array, } // namespace bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, - Model* model, std::size_t op_index) { + Model* model, std::size_t op_index, + int input_index) { const auto passthru_it = model->operators.begin() + op_index; auto* passthru_op = passthru_it->get(); CHECK_EQ(passthru_op->outputs.size(), 1); CHECK_GE(passthru_op->inputs.size(), 1); - int count_nonconstant_input_arrays = 0; - // We call 'main input' the unique nonconstant input array if there is one, - // or else the 0-th input. 
+ int main_input_array_index = 0; - for (int i = 0; i < passthru_op->inputs.size(); i++) { - if (!model->GetArray(passthru_op->inputs[i]).buffer) { - count_nonconstant_input_arrays++; - if (count_nonconstant_input_arrays == 1) { - main_input_array_index = i; + if (input_index != -1) { + main_input_array_index = input_index; + } else { + // We call 'main input' the unique nonconstant input array if there is one, + // or else the 0-th input. + int count_nonconstant_input_arrays = 0; + for (int i = 0; i < passthru_op->inputs.size(); i++) { + if (!model->GetArray(passthru_op->inputs[i]).buffer) { + count_nonconstant_input_arrays++; + if (count_nonconstant_input_arrays == 1) { + main_input_array_index = i; + } } } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h index 9d448c3ee90..663704e5acf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h @@ -50,7 +50,8 @@ namespace toco { // and then discards it and returns true, or, if it's not trivial (if neither // the input nor the output may be discarded), returns false. bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, - Model* model, std::size_t op_index); + Model* model, std::size_t op_index, + int input_index = -1); } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc new file mode 100644 index 00000000000..e880a3f44da --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +// Resolves a constant Select operation. +// +// This implementation is looking strictly for all-or-nothing on the select +// condition. It's possible to enhance this by looking per-element and possibly +// producing a Mul op. 
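Editor's note: RemoveTrivialPassthroughOp now takes an optional input_index so callers such as the new constant-Select resolver can force which input survives, instead of relying on the "unique non-constant input" heuristic. A hedged sketch of just that selection logic (the op is modeled as a list of (name, is_constant) pairs; names are made up for illustration):

```python
def pick_main_input(inputs, input_index=-1):
    """Choose the pass-through op's surviving input.

    inputs: list of (name, is_constant) pairs.
    input_index: if not -1, the caller forces this input (new behavior);
    otherwise fall back to the unique non-constant input, or input 0.
    """
    if input_index != -1:
        return input_index
    main = 0
    nonconstant_seen = 0
    for i, (_, is_constant) in enumerate(inputs):
        if not is_constant:
            nonconstant_seen += 1
            if nonconstant_seen == 1:
                main = i
    return main


if __name__ == "__main__":
    ins = [("cond", True), ("if_true", False), ("if_false", False)]
    assert pick_main_input(ins) == 1                  # heuristic: first non-constant
    assert pick_main_input(ins, input_index=2) == 2   # caller override, as Select does
```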
+bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + const auto* base_op = it->get(); + if (base_op->type != OperatorType::kSelect) { + return false; + } + const auto* op = static_cast(base_op); + + CHECK_GE(op->inputs.size(), 3); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes. + return false; + } + if (!output_array.has_shape()) { + // Yield until the output shape has been set by PropagateFixedShapes. + return false; + } + + // We require the cond input to be constant. + if (!IsConstantParameterArray(*model, op->inputs[0])) { + return false; + } + const Array& cond_array = model->GetArray(op->inputs[0]); + CHECK(cond_array.data_type == ArrayDataType::kBool) + << "Only bool conditions are supported"; + const auto& cond_data = cond_array.GetBuffer().data; + if (cond_data.empty()) { + return false; + } + + // Check if the condition is the same for all elements. + bool cond_value = cond_data[0]; + for (size_t i = 1; i < cond_data.size(); ++i) { + if (cond_data[i] != cond_value) { + AddMessageF( + "Cannot resolve %s as constant; cond_array has differing " + "per-element values", + LogName(*op)); + return false; + } + } + + // Pass-through the selected input. + return RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc new file mode 100644 index 00000000000..0b0d0707146 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc @@ -0,0 +1,173 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +// NOTE: the Tile implementation here is taken from tflite's Tile kernel. 
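Editor's note: ResolveConstantSelect::Run above only fires when the boolean condition is constant and uniform, and then reduces the op to a pass-through of input 1 (all true) or input 2 (all false). A small numpy sketch of that decision, with illustrative names rather than the toco types:

```python
import numpy as np


def resolve_constant_select(cond, if_true_index=1, if_false_index=2):
    """Return the index of the input that should pass through, or None.

    None means the condition is empty or has differing per-element values,
    so the Select op is kept (a per-element fold would be needed instead).
    """
    cond = np.asarray(cond, dtype=bool).ravel()
    if cond.size == 0:
        return None
    if np.all(cond):
        return if_true_index
    if not np.any(cond):
        return if_false_index
    return None


if __name__ == "__main__":
    assert resolve_constant_select([True, True, True]) == 1
    assert resolve_constant_select([False, False]) == 2
    assert resolve_constant_select([True, False]) is None
```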
+ +template +void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier, + T* out_data) { + for (int i = 0; i < multiplier; ++i) { + const T* in_end = in_data + in_size; + T* new_out_data = std::copy(in_data, in_end, out_data); + in_data = out_data; + out_data = new_out_data; + } +} + +template +std::pair TileOneDimension(const Shape& in_dimensions, + const T* in_data, const M* multipliers, + T* out_data, int dimension) { + const int dimension_size = in_dimensions.dims(dimension); + if (dimension == in_dimensions.dimensions_count() - 1) { + CopyMultipleTimes(in_data, dimension_size, multipliers[dimension], + out_data); + return std::make_pair( + dimension_size, + dimension_size * static_cast(multipliers[dimension])); + } + int total_stride_size = 0, total_tiled_stride_size = 0; + const T* copy_from_data = in_data; + T* copy_to_data = out_data; + for (int i = 0; i < dimension_size; ++i) { + int stride_size = 0, tiled_stride_size = 0; + std::tie(stride_size, tiled_stride_size) = + TileOneDimension(in_dimensions, copy_from_data, multipliers, + copy_to_data, dimension + 1); + copy_from_data += stride_size; + copy_to_data += tiled_stride_size; + total_stride_size += stride_size; + total_tiled_stride_size += tiled_stride_size; + } + CopyMultipleTimes(out_data, total_tiled_stride_size, + multipliers[dimension] - 1, + out_data + total_tiled_stride_size); + return std::make_pair(total_stride_size, + total_tiled_stride_size * multipliers[dimension]); +} + +template +inline void Tile(const Array& input_array, const Array& multiples_array, + Array* output_array) { + // Allocate output storage. + auto& output_data = output_array->GetMutableBuffer().data; + output_data.resize(RequiredBufferSizeForShape(output_array->shape())); + + switch (multiples_array.data_type) { + case ArrayDataType::kInt32: + TileOneDimension( + input_array.shape(), input_array.GetBuffer().data.data(), + multiples_array.GetBuffer().data.data(), + output_array->GetMutableBuffer().data.data(), 0); + break; + case ArrayDataType::kInt64: + TileOneDimension( + input_array.shape(), input_array.GetBuffer().data.data(), + multiples_array.GetBuffer().data.data(), + output_array->GetMutableBuffer().data.data(), 0); + break; + default: + CHECK(false); + break; + } +} + +} // namespace + +// Resolves a constant Tile operation. +bool ResolveConstantTile::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + const auto* base_op = it->get(); + if (base_op->type != OperatorType::kTile) { + return false; + } + const auto* op = static_cast(base_op); + + CHECK_GE(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes. + return false; + } + if (!output_array.has_shape()) { + // Yield until the output shape has been set by PropagateFixedShapes. + return false; + } + + // We require constant inputs. + if (!IsConstantParameterArray(*model, op->inputs[0]) || + !IsConstantParameterArray(*model, op->inputs[1])) { + return false; + } + const Array& input_array = model->GetArray(op->inputs[0]); + const Array& multiples_array = model->GetArray(op->inputs[1]); + CHECK(multiples_array.data_type == ArrayDataType::kInt32 || + multiples_array.data_type == ArrayDataType::kInt64) + << "Only int32/int64 indices are supported"; + + // Copy min/max info if present. 
The ranges of the selected values may be + // a subset of the original range but we want to ensure the quantization + // params stay the same. + if (input_array.minmax) { + const auto& input_minmax = input_array.GetMinMax(); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = input_minmax.min; + output_minmax.max = input_minmax.max; + } + + CHECK(!output_array.buffer); + switch (output_array.data_type) { + case ArrayDataType::kFloat: + Tile(input_array, multiples_array, &output_array); + break; + case ArrayDataType::kUint8: + Tile(input_array, multiples_array, &output_array); + break; + case ArrayDataType::kInt16: + Tile(input_array, multiples_array, &output_array); + break; + case ArrayDataType::kInt32: + Tile(input_array, multiples_array, &output_array); + break; + case ArrayDataType::kInt64: + Tile(input_array, multiples_array, &output_array); + break; + default: + LOG(FATAL) << "Unsupported data type given to Tile op with output \"" + << op->outputs[0] << "\""; + break; + } + + // Erase input arrays if no longer used after we remove the op. + DeleteArrayIfUsedOnce(op->inputs[0], model); + DeleteArrayIfUsedOnce(op->inputs[1], model); + + // Erase the operator. + model->operators.erase(it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index fe3882c28df..475415e4814 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -246,8 +246,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } output_float_data[i] = outval; } - } else if (unary_op->type == OperatorType::kRelu6 && - unary_op->type == OperatorType::kRelu1 && + } else if (unary_op->type == OperatorType::kRelu6 || + unary_op->type == OperatorType::kRelu1 || unary_op->type == OperatorType::kRelu) { for (size_t i = 0; i < output_buffer_size; ++i) { const float value = (*input_float_data)[i]; diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc index 14168fa33f7..204c0d101ea 100644 --- a/tensorflow/contrib/lite/toco/toco_port.cc +++ b/tensorflow/contrib/lite/toco/toco_port.cc @@ -138,13 +138,15 @@ namespace port { #define close _close #define open _open #define read _read -#define O_RDONLY _O_RDONLY -#define O_CREAT _O_CREAT -#define O_WRONLY _O_WRONLY -// Windows does not support the same set of file permissions as other platforms. +// Windows does not support the same set of file permissions as other platforms, +// and also requires an explicit flag for binary file read/write support. 
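Editor's note: the CopyMultipleTimes/TileOneDimension helpers above fold a constant Tile by recursively copying each dimension's data multiplier-many times. As a sanity sketch (not the toco code path), the folded result should match numpy's tile for any constant input and one multiplier per dimension:

```python
import numpy as np


def fold_constant_tile(input_values, multiples):
    """Reference constant-fold for a Tile op: repeat each axis per 'multiples'.

    input_values: the constant input tensor (any numeric dtype).
    multiples: one integer per input dimension (int32/int64 in the real op).
    """
    values = np.asarray(input_values)
    multiples = np.asarray(multiples)
    if multiples.shape != (values.ndim,):
        raise ValueError("need exactly one multiplier per input dimension")
    return np.tile(values, multiples)


if __name__ == "__main__":
    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    out = fold_constant_tile(x, [2, 1])
    assert out.shape == (4, 3)
    # Tiling only repeats existing values, so the output range is a subset of
    # the input range; that is why the transformation copies the input's
    # minmax onto the output unchanged.
    assert out.min() >= x.min() and out.max() <= x.max()
```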
constexpr int kFileCreateMode = _S_IREAD | _S_IWRITE; +constexpr int kFileReadFlags = _O_RDONLY | _O_BINARY; +constexpr int kFileWriteFlags = _O_WRONLY | _O_BINARY | _O_CREAT; #else constexpr int kFileCreateMode = 0664; +constexpr int kFileReadFlags = O_RDONLY; +constexpr int kFileWriteFlags = O_CREAT | O_WRONLY; #endif // _WIN32 static bool port_initialized = false; @@ -197,7 +199,7 @@ tensorflow::Status GetContents(const string& path, string* output, const file::Options& options) { output->clear(); - int fd = open(path.c_str(), O_RDONLY); + int fd = open(path.c_str(), kFileReadFlags); if (fd == -1) { return tensorflow::errors::NotFound("can't open() for read"); } @@ -226,7 +228,7 @@ tensorflow::Status GetContents(const string& path, string* output, tensorflow::Status SetContents(const string& filename, const string& contents, const file::Options& options) { - int fd = open(filename.c_str(), O_WRONLY | O_CREAT, kFileCreateMode); + int fd = open(filename.c_str(), kFileWriteFlags, kFileCreateMode); if (fd == -1) { return tensorflow::errors::Internal("can't open() for write"); } diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index fcd3cbab07c..34130a02b03 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -90,8 +90,10 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveConstantRandomUniform); transformations->Add(new ResolveConstantRange); transformations->Add(new ResolveConstantReshape); + transformations->Add(new ResolveConstantSelect); transformations->Add(new ResolveConstantSlice); transformations->Add(new ResolveConstantStridedSlice); + transformations->Add(new ResolveConstantTile); transformations->Add(new ResolveConstantTranspose); transformations->Add(new ResolveConstantUnaryOperator); transformations->Add(new ResolveTensorFlowMerge); diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/tools/make/Makefile similarity index 67% rename from tensorflow/contrib/lite/Makefile rename to tensorflow/contrib/lite/tools/make/Makefile index 9cc8f10b429..e30cc1d70e1 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/tools/make/Makefile @@ -6,120 +6,74 @@ endif # Try to figure out the host system HOST_OS := ifeq ($(OS),Windows_NT) - HOST_OS = WINDOWS + HOST_OS = windows else UNAME_S := $(shell uname -s) ifeq ($(UNAME_S),Linux) - HOST_OS := LINUX + HOST_OS := linux endif ifeq ($(UNAME_S),Darwin) - HOST_OS := OSX + HOST_OS := osx endif endif HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) -# Self-hosting -TARGET_ARCH := ${HOST_ARCH} +# Override these on the make command line to target a specific architecture. For example: +# make -f tensorflow/contrib/lite/Makefile TARGET=rpi TARGET_ARCH=armv7l +TARGET := $(HOST_OS) +TARGET_ARCH := $(HOST_ARCH) -# Cross compiling -ifeq ($(CROSS),rpi) - TARGET_ARCH := armv7l - TARGET_TOOLCHAIN_PREFIX := arm-linux-gnueabihf- -endif +# These are the default libraries needed, but they can be added to or +# overridden by the platform-specific settings in target makefiles. 
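Editor's note: the toco_port change above stops redefining O_RDONLY/O_WRONLY and instead builds explicit flag sets that include _O_BINARY on Windows, so model files are not run through CRLF/ctrl-Z text-mode translation. A hedged Python illustration of the same idea (os.O_BINARY only exists on Windows builds of Python, hence the getattr fallback; the helper names are illustrative):

```python
import os

# O_BINARY is Windows-only; on POSIX platforms every open is already binary.
O_BINARY = getattr(os, "O_BINARY", 0)

READ_FLAGS = os.O_RDONLY | O_BINARY
WRITE_FLAGS = os.O_WRONLY | os.O_CREAT | O_BINARY


def get_contents(path):
    """Read a file's raw bytes without any text-mode translation."""
    fd = os.open(path, READ_FLAGS)
    try:
        chunks = []
        while True:
            chunk = os.read(fd, 1 << 16)
            if not chunk:
                return b"".join(chunks)
            chunks.append(chunk)
    finally:
        os.close(fd)


def set_contents(path, data, mode=0o664):
    """Write raw bytes; 'mode' plays the role of kFileCreateMode on POSIX."""
    fd = os.open(path, WRITE_FLAGS, mode)
    try:
        os.write(fd, data)
    finally:
        os.close(fd)
```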
+LIBS := \ +-lstdc++ \ +-lpthread \ +-lm \ +-lz -ifeq ($(CROSS),riscv) - TARGET_ARCH := riscv - TARGET_TOOLCHAIN_PREFIX := riscv32-unknown-elf- -endif -ifeq ($(CROSS),stm32f7) - TARGET_ARCH := armf7 - TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- -endif -ifeq ($(CROSS),stm32f1) - TARGET_ARCH := armm1 - TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- -endif - -# Where compiled objects are stored. -OBJDIR := $(MAKEFILE_DIR)/gen/obj/ -BINDIR := $(MAKEFILE_DIR)/gen/bin/ -LIBDIR := $(MAKEFILE_DIR)/gen/lib/ -GENDIR := $(MAKEFILE_DIR)/gen/obj/ - -LIBS := -ifeq ($(TARGET_ARCH),x86_64) - CXXFLAGS += -fPIC -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -pthread # -msse4.2 -endif - -ifeq ($(TARGET_ARCH),armv7l) - CXXFLAGS += -mfpu=neon -pthread -fPIC - LIBS += -ldl -endif - -ifeq ($(TARGET_ARCH),riscv) -# CXXFLAGS += -march=gap8 - CXXFLAGS += -DTFLITE_MCU - LIBS += -ldl - BUILD_TYPE := micro -endif - -ifeq ($(TARGET_ARCH),armf7) - CXXFLAGS += -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -DTFLITE_MCU - CXXFLAGS += -fno-rtti -fmessage-length=0 -fno-exceptions -fno-builtin -ffunction-sections -fdata-sections - CXXFLAGS += -funsigned-char -MMD - CXXFLAGS += -mcpu=cortex-m7 -mthumb -mfpu=fpv5-sp-d16 -mfloat-abi=softfp - CXXFLAGS += '-std=gnu++11' '-fno-rtti' '-Wvla' '-c' '-Wall' '-Wextra' '-Wno-unused-parameter' '-Wno-missing-field-initializers' '-fmessage-length=0' '-fno-exceptions' '-fno-builtin' '-ffunction-sections' '-fdata-sections' '-funsigned-char' '-MMD' '-fno-delete-null-pointer-checks' '-fomit-frame-pointer' '-Os' - LIBS += -ldl - BUILD_TYPE := micro -endif -ifeq ($(TARGET_ARCH),armm1) - CXXFLAGS += -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -mcpu=cortex-m1 -mthumb -DTFLITE_MCU - CXXFLAGS += -fno-rtti -fmessage-length=0 -fno-exceptions -fno-builtin -ffunction-sections -fdata-sections - CXXFLAGS += -funsigned-char -MMD - LIBS += -ldl -endif - -# Settings for the host compiler. -CXX := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}g++ -CXXFLAGS += -O3 -DNDEBUG +# There are no rules for compiling objects for the host system (since we don't +# generate things like the protobuf compiler that require that), so all of +# these settings are for the target compiler. +CXXFLAGS := -O3 -DNDEBUG CCFLAGS := ${CXXFLAGS} CXXFLAGS += --std=c++11 -CC := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}gcc -AR := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}ar CFLAGS := -LDOPTS := -LDOPTS += -L/usr/local/lib +LDOPTS := -L/usr/local/lib ARFLAGS := -r +TARGET_TOOLCHAIN_PREFIX := +CC_PREFIX := + +# These target-specific makefiles should modify or replace options like +# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic +# based on platforms or architectures should happen within these files, to +# keep this main makefile focused on the sources and dependencies. +include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc) + +# Where compiled objects are stored. +GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/ +OBJDIR := $(GENDIR)obj/ +BINDIR := $(GENDIR)bin/ +LIBDIR := $(GENDIR)lib/ INCLUDES := \ -I. \ --I$(MAKEFILE_DIR)/../../../ \ --I$(MAKEFILE_DIR)/../../../../ \ +-I$(MAKEFILE_DIR)/../../../../../ \ +-I$(MAKEFILE_DIR)/../../../../../../ \ -I$(MAKEFILE_DIR)/downloads/ \ -I$(MAKEFILE_DIR)/downloads/eigen \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ -I$(MAKEFILE_DIR)/downloads/neon_2_sse \ -I$(MAKEFILE_DIR)/downloads/farmhash/src \ -I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ --I$(GENDIR) +-I$(OBJDIR) # This is at the end so any globally-installed frameworks like protobuf don't # override local versions in the source tree. 
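Editor's note: the relocated Makefile now writes all outputs under a per-target directory, gen/$(TARGET)_$(TARGET_ARCH)/{obj,bin,lib}, so builds for different platforms no longer overwrite each other; this is also why the iOS lipo script below switches to gen/ios_<arch>/lib paths. A tiny illustrative helper (not part of the build) that computes that layout:

```python
import os


def gen_dirs(makefile_dir, target, target_arch):
    """Per-target output layout: gen/<TARGET>_<TARGET_ARCH>/{obj,bin,lib}/."""
    gendir = os.path.join(makefile_dir, "gen", "%s_%s" % (target, target_arch))
    return {name: os.path.join(gendir, name) for name in ("obj", "bin", "lib")}


if __name__ == "__main__":
    dirs = gen_dirs("tensorflow/contrib/lite/tools/make", "rpi", "armv7l")
    print(dirs["lib"])  # tensorflow/contrib/lite/tools/make/gen/rpi_armv7l/lib
```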
INCLUDES += -I/usr/local/include -LIBS += \ --lstdc++ \ --lpthread \ --lm \ --lz - -# If we're on Linux, also link in the dl library. -ifeq ($(HOST_OS),LINUX) - LIBS += -ldl -endif - -include $(MAKEFILE_DIR)/ios_makefile.inc -include $(MAKEFILE_DIR)/rpi_makefile.inc +CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++ +CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc +AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar # This library is the main target for this makefile. It will contain a minimal # runtime that can be linked in to other programs. @@ -163,8 +117,8 @@ $(wildcard tensorflow/contrib/lite/kernels/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \ -$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) \ -$(wildcard tensorflow/contrib/lite/downloads/fft2d/fftsg.c) +$(wildcard tensorflow/contrib/lite/tools/make/downloads/farmhash/src/farmhash.cc) \ +$(wildcard tensorflow/contrib/lite/tools/make/downloads/fft2d/fftsg.c) endif # Remove any duplicates. CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) @@ -179,10 +133,6 @@ ifeq ($(BUILD_TYPE),micro) CORE_CC_EXCLUDE_SRCS += \ tensorflow/contrib/lite/mmap_allocation.cc \ tensorflow/contrib/lite/nnapi_delegate.cc -else -CORE_CC_EXCLUDE_SRCS += \ -tensorflow/contrib/lite/mmap_allocation_disabled.cc \ -tensorflow/contrib/lite/nnapi_delegate_disabled.cc endif # Filter out all the excluded files. TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh similarity index 66% rename from tensorflow/contrib/lite/build_ios_universal_lib.sh rename to tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh index 31df43a1754..fe056945a65 100755 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ b/tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh @@ -17,23 +17,23 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -cd "$SCRIPT_DIR/../../.." +cd "$SCRIPT_DIR/../../../../.." # Build library for supported architectures and packs them in a fat binary. 
make_library() { for arch in x86_64 armv7 armv7s arm64 do - make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=${arch} \ - -j 8 \ - $SCRIPT_DIR/gen/lib/ios_${arch}/${1} + make -f tensorflow/contrib/lite/tools/make/Makefile TARGET=ios TARGET_ARCH=${arch} \ + -j 8 done + mkdir -p tensorflow/contrib/lite/tools/make/gen/lib lipo \ - tensorflow/contrib/lite/gen/lib/ios_x86_64/${1} \ - tensorflow/contrib/lite/gen/lib/ios_armv7/${1} \ - tensorflow/contrib/lite/gen/lib/ios_armv7s/${1} \ - tensorflow/contrib/lite/gen/lib/ios_arm64/${1} \ + tensorflow/contrib/lite/tools/make/gen/ios_x86_64/lib/${1} \ + tensorflow/contrib/lite/tools/make/gen/ios_armv7/lib/${1} \ + tensorflow/contrib/lite/tools/make/gen/ios_armv7s/lib/${1} \ + tensorflow/contrib/lite/tools/make/gen/ios_arm64/lib/${1} \ -create \ - -output tensorflow/contrib/lite/gen/lib/${1} + -output tensorflow/contrib/lite/tools/make/gen/lib/${1} } make_library libtensorflow-lite.a diff --git a/tensorflow/contrib/lite/build_rpi_lib.sh b/tensorflow/contrib/lite/tools/make/build_rpi_lib.sh similarity index 90% rename from tensorflow/contrib/lite/build_rpi_lib.sh rename to tensorflow/contrib/lite/tools/make/build_rpi_lib.sh index 3824b16412e..24ecd4356df 100755 --- a/tensorflow/contrib/lite/build_rpi_lib.sh +++ b/tensorflow/contrib/lite/tools/make/build_rpi_lib.sh @@ -17,6 +17,6 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -cd "$SCRIPT_DIR/../../.." +cd "$SCRIPT_DIR/../../../../.." -CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/contrib/lite/Makefile TARGET=RPI TARGET_ARCH=armv7 +CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/contrib/lite/tools/make/Makefile TARGET=rpi TARGET_ARCH=armv7l diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/tools/make/download_dependencies.sh similarity index 98% rename from tensorflow/contrib/lite/download_dependencies.sh rename to tensorflow/contrib/lite/tools/make/download_dependencies.sh index 8c7df474d55..29afa451337 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/tools/make/download_dependencies.sh @@ -17,9 +17,9 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -cd "$SCRIPT_DIR/../../.." +cd "$SCRIPT_DIR/../../../../.." -DOWNLOADS_DIR=tensorflow/contrib/lite/downloads +DOWNLOADS_DIR=tensorflow/contrib/lite/tools/make/downloads BZL_FILE_PATH=tensorflow/workspace.bzl # Ensure it is being run from repo root diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/ios_makefile.inc similarity index 67% rename from tensorflow/contrib/lite/ios_makefile.inc rename to tensorflow/contrib/lite/tools/make/targets/ios_makefile.inc index 079320586ff..7f36b8ecef4 100644 --- a/tensorflow/contrib/lite/ios_makefile.inc +++ b/tensorflow/contrib/lite/tools/make/targets/ios_makefile.inc @@ -1,11 +1,11 @@ # Settings for iOS. 
-ifeq ($(TARGET), IOS) - BUILD_FOR_IOS_SIMULATOR := false - ifeq ($(IOS_ARCH), x86_64) - BUILD_FOR_IOS_SIMULATOR := true +ifeq ($(TARGET), ios) + BUILD_FOR_IOS_SIMULATOR := false + ifeq ($(TARGET_ARCH), x86_64) + BUILD_FOR_IOS_SIMULATOR := true endif - ifeq ($(IOS_ARCH), i386) - BUILD_FOR_IOS_SIMULATOR := true + ifeq ($(TARGET_ARCH), i386) + BUILD_FOR_IOS_SIMULATOR := true endif ifeq ($(BUILD_FOR_IOS_SIMULATOR), true) IPHONEOS_PLATFORM := $(shell xcrun --sdk iphonesimulator \ @@ -18,8 +18,8 @@ ifeq ($(TARGET), IOS) endif IOS_SDK_VERSION := $(shell xcrun --sdk iphoneos --show-sdk-version) MIN_SDK_VERSION := 9.0 - # Override IOS_ARCH with armv7, armv7s, arm64, i386, or x86_64. - IOS_ARCH := x86_64 + # Override TARGET_ARCH with armv7, armv7s, arm64, i386, or x86_64. + TARGET_ARCH := x86_64 CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ -DTFLITE_USE_APPLE_ACCELERATE_FOR_CONV \ @@ -29,21 +29,17 @@ ifeq ($(TARGET), IOS) -fno-exceptions \ -isysroot \ ${IPHONEOS_SYSROOT} \ - -arch $(IOS_ARCH) \ + -arch $(TARGET_ARCH) \ -O3 CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ -fembed-bitcode \ -mno-thumb \ -isysroot \ ${IPHONEOS_SYSROOT} \ - -arch $(IOS_ARCH) \ + -arch $(TARGET_ARCH) \ -O3 LDFLAGS := -fembed-bitcode \ -miphoneos-version-min=${MIN_SDK_VERSION} \ -framework Accelerate \ - -arch $(IOS_ARCH) - OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/ - LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/ - BINDIR := $(BINDIR)ios_$(IOS_ARCH)/ - DEPDIR := $(DEPDIR)ios_$(IOS_ARCH)/ + -arch $(TARGET_ARCH) endif diff --git a/tensorflow/contrib/lite/tools/make/targets/linux_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/linux_makefile.inc new file mode 100644 index 00000000000..86499da99e2 --- /dev/null +++ b/tensorflow/contrib/lite/tools/make/targets/linux_makefile.inc @@ -0,0 +1,10 @@ +# Settings for Linux. +ifeq ($(TARGET), linux) + CXXFLAGS += \ + -fPIC \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -pthread + # TODO(petewarden): In the future we may want to add architecture-specific + # flags like -msse4.2 + LIBS += -ldl +endif diff --git a/tensorflow/contrib/lite/tools/make/targets/riscv_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/riscv_makefile.inc new file mode 100644 index 00000000000..1a82afec33e --- /dev/null +++ b/tensorflow/contrib/lite/tools/make/targets/riscv_makefile.inc @@ -0,0 +1,10 @@ +# Settings for RiscV platforms. +ifeq ($(TARGET), riscv) + TARGET_ARCH := riscv + TARGET_TOOLCHAIN_PREFIX := riscv32-unknown-elf- + + #CXXFLAGS += -march=gap8 + CXXFLAGS += -DTFLITE_MCU + LIBS += -ldl + BUILD_TYPE := micro +endif diff --git a/tensorflow/contrib/lite/tools/make/targets/rpi_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/rpi_makefile.inc new file mode 100644 index 00000000000..1ad0c502372 --- /dev/null +++ b/tensorflow/contrib/lite/tools/make/targets/rpi_makefile.inc @@ -0,0 +1,60 @@ +# Settings for Raspberry Pi. +ifeq ($(TARGET),rpi) + # Default to the architecture used on the Pi Two/Three (ArmV7), but override this + # with TARGET_ARCH=armv6 to build for the Pi Zero or One. 
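Editor's note: the new targets/*_makefile.inc files (linux and riscv above, rpi and the stm32 variants just below) replace the old CROSS= blocks: each one keys off TARGET, fixes TARGET_ARCH and the toolchain prefix, and appends its own flags. A rough Python table of the mapping encoded by these snippets, meant as documentation rather than a build tool:

```python
# TARGET -> (default TARGET_ARCH, TARGET_TOOLCHAIN_PREFIX, build type)
TARGET_DEFAULTS = {
    "linux":   ("<host arch>", "",                     "full"),
    "ios":     ("x86_64",      "",                     "full"),
    "rpi":     ("armv7l",      "arm-linux-gnueabihf-", "full"),
    "riscv":   ("riscv",       "riscv32-unknown-elf-", "micro"),
    "stm32f1": ("armm1",       "arm-none-eabi-",       "micro"),
    "stm32f7": ("armf7",       "arm-none-eabi-",       "micro"),
}


def toolchain_tools(target, cc_prefix=""):
    """Return the (CXX, CC, AR) command names the Makefile would assemble."""
    _, prefix, _ = TARGET_DEFAULTS[target]
    return tuple(cc_prefix + prefix + tool for tool in ("g++", "gcc", "ar"))


if __name__ == "__main__":
    assert toolchain_tools("rpi") == (
        "arm-linux-gnueabihf-g++", "arm-linux-gnueabihf-gcc",
        "arm-linux-gnueabihf-ar")
```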
+ TARGET_ARCH := armv7l + TARGET_TOOLCHAIN_PREFIX := arm-linux-gnueabihf- + + ifeq ($(TARGET_ARCH), armv7l) + CXXFLAGS += \ + -march=armv7-a \ + -mfpu=neon-vfpv4 \ + -funsafe-math-optimizations \ + -ftree-vectorize \ + -fPIC + + CCFLAGS += \ + -march=armv7-a \ + -mfpu=neon-vfpv4 \ + -funsafe-math-optimizations \ + -ftree-vectorize \ + -fPIC + + LDFLAGS := \ + -Wl,--no-export-dynamic \ + -Wl,--exclude-libs,ALL \ + -Wl,--gc-sections \ + -Wl,--as-needed + endif + + # TODO(petewarden) In the future, we'll want to use OpenBLAS as a faster + # alternative to Eigen on non-NEON ARM hardware like armv6. + ifeq ($(TARGET_ARCH), armv6) + CXXFLAGS += \ + -march=armv6 \ + -mfpu=vfp \ + -funsafe-math-optimizations \ + -ftree-vectorize \ + -fPIC + + CCFLAGS += \ + -march=armv6 \ + -mfpu=vfp \ + -funsafe-math-optimizations \ + -ftree-vectorize \ + -fPIC + + LDFLAGS := \ + -Wl,--no-export-dynamic \ + -Wl,--exclude-libs,ALL \ + -Wl,--gc-sections \ + -Wl,--as-needed + endif + + LIBS := \ + -lstdc++ \ + -lpthread \ + -lm \ + -ldl + +endif diff --git a/tensorflow/contrib/lite/tools/make/targets/stm32f1_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/stm32f1_makefile.inc new file mode 100644 index 00000000000..7418e4d196e --- /dev/null +++ b/tensorflow/contrib/lite/tools/make/targets/stm32f1_makefile.inc @@ -0,0 +1,21 @@ +# Settings for STM32F1 platforms. +ifeq ($(TARGET), stm32f1) + TARGET_ARCH := armm1 + TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- + + CXXFLAGS += \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -mcpu=cortex-m1 \ + -mthumb \ + -DTFLITE_MCU \ + -fno-rtti \ + -fmessage-length=0 \ + -fno-exceptions \ + -fno-builtin \ + -ffunction-sections \ + -fdata-sections \ + -funsigned-char \ + -MMD + LIBS += -ldl + BUILD_TYPE := micro +endif diff --git a/tensorflow/contrib/lite/tools/make/targets/stm32f7_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/stm32f7_makefile.inc new file mode 100644 index 00000000000..48af71e5b4b --- /dev/null +++ b/tensorflow/contrib/lite/tools/make/targets/stm32f7_makefile.inc @@ -0,0 +1,41 @@ +# Settings for STM32F7 platforms. +ifeq ($(TARGET), stm32f7) + TARGET_ARCH := armf7 + TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- + + CXXFLAGS += \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -DTFLITE_MCU \ + -fno-rtti \ + -fmessage-length=0 \ + -fno-exceptions \ + -fno-builtin \ + -ffunction-sections \ + -fdata-sections \ + -funsigned-char \ + -MMD \ + -mcpu=cortex-m7 \ + -mthumb \ + -mfpu=fpv5-sp-d16 \ + -mfloat-abi=softfp \ + -std=gnu++11 \ + -fno-rtti \ + -Wvla \ + -c \ + -Wall \ + -Wextra \ + -Wno-unused-parameter \ + -Wno-missing-field-initializers \ + -fmessage-length=0 \ + -fno-exceptions \ + -fno-builtin \ + -ffunction-sections \ + -fdata-sections \ + -funsigned-char \ + -MMD \ + -fno-delete-null-pointer-checks \ + -fomit-frame-pointer \ + -Os + LIBS += -ldl + BUILD_TYPE := micro +endif diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 4942d941765..8c0bfefb303 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_lookup_ops from tensorflow.python.ops import lookup_ops # pylint: disable=unused-import @@ -395,17 +394,12 @@ class MutableHashTable(LookupInterface): Raises: TypeError: when `keys` do not match the table data types. 
""" - if keys.dtype.base_dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % - (self._key_dtype, keys.dtype)) - with ops.name_scope(name, "%s_lookup_table_find" % self._name, (self._table_ref, keys, self._default_value)) as name: + keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") with ops.colocate_with(self._table_ref): values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, keys, self._default_value, name=name) - - values.set_shape(keys.get_shape().concatenate(self._value_shape)) return values def insert(self, keys, values, name=None): @@ -451,9 +445,6 @@ class MutableHashTable(LookupInterface): with ops.colocate_with(self._table_ref): exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( self._table_ref, self._key_dtype, self._value_dtype, name=name) - - exported_values.set_shape(exported_keys.get_shape().concatenate( - self._value_shape)) return exported_keys, exported_values class _Saveable(BaseSaverBuilder.SaveableObject): @@ -537,14 +528,15 @@ class MutableDenseHashTable(LookupInterface): ValueError: If checkpoint is True and no name was specified. """ self._default_value = ops.convert_to_tensor( - default_value, dtype=value_dtype) + default_value, dtype=value_dtype, name="default_value") self._value_shape = self._default_value.get_shape() # The table must be shared if checkpointing is requested for multi-worker # training to work correctly. Use the node name if no shared_name has been # explicitly specified. use_node_name_sharing = checkpoint and shared_name is None - empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype) + empty_key = ops.convert_to_tensor( + empty_key, dtype=key_dtype, name="empty_key") self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( empty_key=empty_key, shared_name=shared_name, @@ -591,20 +583,13 @@ class MutableDenseHashTable(LookupInterface): Raises: TypeError: when `keys` do not match the table data types. """ - if keys.dtype.base_dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % - (self._key_dtype, keys.dtype)) - with ops.name_scope(name, "%s_lookup_table_find" % self._name, [self._table_ref, keys]) as name: + keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") with ops.colocate_with(self._table_ref): values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, keys, self._default_value, name=name) - if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0: - values.set_shape( - tensor_shape.TensorShape([keys.get_shape().dims[0]]).concatenate( - self._value_shape)) return values def insert(self, keys, values, name=None): @@ -624,11 +609,11 @@ class MutableDenseHashTable(LookupInterface): TypeError: when `keys` or `values` doesn't match the table data types. 
""" - # pylint: disable=protected-access - lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype) - # pylint: enable=protected-access with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: + keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") + values = ops.convert_to_tensor( + values, dtype=self._value_dtype, name="values") with ops.colocate_with(self._table_ref): op = gen_lookup_ops.lookup_table_insert_v2( self._table_ref, keys, values, name=name) @@ -650,8 +635,6 @@ class MutableDenseHashTable(LookupInterface): exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( self._table_ref, self._key_dtype, self._value_dtype, name=name) - exported_values.set_shape(exported_keys.get_shape().concatenate( - self._value_shape)) return exported_keys, exported_values class _Saveable(BaseSaverBuilder.SaveableObject): diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 8d510ede582..6fb5244fc62 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -434,8 +434,10 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result) exported_keys, exported_values = table.export() - self.assertAllEqual([None], exported_keys.get_shape().as_list()) - self.assertAllEqual([None, 2], exported_values.get_shape().as_list()) + self.assertAllEqual([None], exported_keys.get_shape().as_list(), + msg="Saw shape %s" % exported_keys.shape) + self.assertAllEqual([None, 2], exported_values.get_shape().as_list(), + msg="Saw shape %s" % exported_values.shape) # exported data is in the order of the internal map, i.e. undefined sorted_keys = np.sort(exported_keys.eval()) sorted_values = np.sort(exported_values.eval()) @@ -669,7 +671,7 @@ class MutableHashTableOpTest(test.TestCase): # lookup with keys of the wrong type input_string = constant_op.constant([1, 2, 3], dtypes.int64) - with self.assertRaises(TypeError): + with self.assertRaises(ValueError): table.lookup(input_string).eval() # default value of the wrong type @@ -853,7 +855,8 @@ class MutableDenseHashTableOpTest(test.TestCase): input_string = constant_op.constant([11, 12, 15], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual([3, 4], output.get_shape()) + self.assertAllEqual( + [3, 4], output.shape, msg="Saw shape: %s" % output.shape) result = output.eval() self.assertAllEqual([[0, 1, 2, 3], [3, 4, 5, 6], [-1, -2, -3, -4]], diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 448ae6d22e6..dc9b17a6278 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -35,7 +35,9 @@ NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\. # process. For now we're hardcoding to the version which is used by # TensorFlow 1.9. PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz" -RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' once +# the archive has been propagated in mirror.bazel.build. 
+RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" DOUBLE_CONVERSION_URL="$(grep -o "https.*google/double-conversion.*\.zip" "${BZL_FILE_PATH}" | head -n1)" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py index e5536122698..7053907da05 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification.py +++ b/tensorflow/contrib/metrics/python/metrics/classification.py @@ -24,7 +24,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics_impl from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context # TODO(nsilberman): move into metrics/python/ops/ @@ -174,7 +174,7 @@ def f1_score(labels, predictions, weights=None, num_thresholds=200, ops.add_to_collections(metrics_collections, best_f1) return best_f1 - best_f1 = distribute_lib.get_tower_context().merge_call( + best_f1 = distribution_strategy_context.get_tower_context().merge_call( f1_across_towers, values) update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'], diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 8c11d8bcfdf..f6ecaba8346 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator from tensorflow.python.training.checkpointable import base as checkpointable @@ -620,7 +621,7 @@ class OptimizerV2(optimizer_v1.Optimizer): # Map from graph_key to state for that graph. We use the graph_key # since it works in both eager and graph mode, and gives the outer # graph inside functions. - tower_context = distribute_lib.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() if tower_context is None: # In a cross-tower context for a DistributionStrategy, which means # only one Optimizer will be created, not one per tower. @@ -769,7 +770,8 @@ class OptimizerV2(optimizer_v1.Optimizer): distribute_lib.get_loss_reduction() == variable_scope.VariableAggregation.MEAN) if scale_loss_by_num_towers: - num_towers = distribute_lib.get_distribution_strategy().num_towers + num_towers = distribution_strategy_context.get_distribution_strategy( + ).num_towers if num_towers > 1: loss_value *= 1. / num_towers @@ -788,7 +790,8 @@ class OptimizerV2(optimizer_v1.Optimizer): distribute_lib.get_loss_reduction() == variable_scope.VariableAggregation.MEAN) if scale_loss_by_num_towers: - num_towers = distribute_lib.get_distribution_strategy().num_towers + num_towers = distribution_strategy_context.get_distribution_strategy( + ).num_towers if num_towers > 1: loss *= 1. 
/ num_towers @@ -862,7 +865,7 @@ class OptimizerV2(optimizer_v1.Optimizer): if not filtered: raise ValueError("No gradients provided for any variable: %s." % ([str(v) for _, v in grads_and_vars],)) - return distribute_lib.get_tower_context().merge_call( + return distribution_strategy_context.get_tower_context().merge_call( self._distributed_apply, filtered, global_step=global_step, name=name) def _get_or_create_state(self, var_list=None): diff --git a/tensorflow/contrib/optimizer_v2/rmsprop.py b/tensorflow/contrib/optimizer_v2/rmsprop.py index 164ff0ea067..3de53405ec1 100644 --- a/tensorflow/contrib/optimizer_v2/rmsprop.py +++ b/tensorflow/contrib/optimizer_v2/rmsprop.py @@ -22,7 +22,7 @@ A detailed description of rmsprop. - divide gradient by the root of this average mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2 -mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square + epsilon) +mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square) delta = - mom This implementation of RMSProp uses plain momentum, not Nesterov momentum. @@ -33,7 +33,7 @@ gradients, and uses that average to estimate the variance: mean_grad = decay * mean_square{t-1} + (1-decay) * gradient mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2 mom = momentum * mom{t-1} + learning_rate * g_t / - sqrt(mean_square - mean_grad**2 + epsilon) + sqrt(mean_square - mean_grad**2) delta = - mom """ @@ -43,7 +43,6 @@ from __future__ import print_function from tensorflow.contrib.optimizer_v2 import optimizer_v2 from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops from tensorflow.python.training import training_ops @@ -87,7 +86,8 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2): decay: A float hyperparameter. Discounting factor for the history/coming gradient. momentum: A float hyperparameter. - epsilon: A float hyperparameter. Small value to avoid zero denominator. + epsilon: A float hyperparameter. Small value to initialize the average + square gradient variable and avoid zero denominator. use_locking: If True use locks for update operation. centered: If True, gradients are normalized by the estimated variance of the gradient; if False, by the uncentered second moment. Setting this to @@ -106,10 +106,8 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2): def _create_vars(self, var_list, state): for v in var_list: - if v.get_shape().is_fully_defined(): - init_rms = init_ops.ones_initializer(dtype=v.dtype.base_dtype) - else: - init_rms = array_ops.ones_like(v) + init_rms = state.get_hyper( + "epsilon", v.dtype.base_dtype) * array_ops.ones_like(v) state.create_slot_with_initializer(v, init_rms, v.get_shape(), v.dtype.base_dtype, "rms") if self._centered: @@ -129,7 +127,9 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2): state.get_hyper("learning_rate", var.dtype.base_dtype), state.get_hyper("decay", var.dtype.base_dtype), state.get_hyper("momentum", var.dtype.base_dtype), - state.get_hyper("epsilon", var.dtype.base_dtype), + # epsilon is now the rms initial value and is not added to the + # denominator anymore, hence calling the kernel op with epsilon=0. 
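Editor's note: the contrib RMSProp change above moves epsilon out of the denominator and uses it only to seed the rms slot, passing 0 for epsilon to the kernel. A numpy-style sketch comparing the two formulations for a single step (the values are arbitrary and this is not the production kernel):

```python
import numpy as np


def rmsprop_step_old(var, g, rms, lr, decay, epsilon):
    """Old formulation: rms slot starts at 1.0, epsilon added under the sqrt."""
    rms = decay * rms + (1.0 - decay) * g * g
    return var - lr * g / np.sqrt(rms + epsilon), rms


def rmsprop_step_new(var, g, rms, lr, decay):
    """New formulation: epsilon is only the rms initial value; the denominator
    is sqrt(rms) with nothing added (the kernel is called with epsilon=0)."""
    rms = decay * rms + (1.0 - decay) * g * g
    return var - lr * g / np.sqrt(rms), rms


if __name__ == "__main__":
    var, g, lr, decay, epsilon = 1.0, 0.1, 2.0, 0.9, 1.0
    v_old, _ = rmsprop_step_old(var, g, rms=1.0, lr=lr, decay=decay,
                                epsilon=epsilon)
    v_new, _ = rmsprop_step_new(var, g, rms=epsilon, lr=lr, decay=decay)
    # With epsilon == 1.0 the two differ only in where the 1.0 enters:
    # old denominator sqrt(0.901 + 1.0) vs new denominator sqrt(0.901),
    # which is exactly the expectation change made in rmsprop_test.py below.
    print(v_old, v_new)
```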
+ 0, grad, use_locking=self._use_locking).op else: @@ -140,7 +140,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2): state.get_hyper("learning_rate", var.dtype.base_dtype), state.get_hyper("decay", var.dtype.base_dtype), state.get_hyper("momentum", var.dtype.base_dtype), - state.get_hyper("epsilon", var.dtype.base_dtype), + 0, grad, use_locking=self._use_locking).op @@ -157,7 +157,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2): state.get_hyper("learning_rate", var.dtype.base_dtype), state.get_hyper("decay", var.dtype.base_dtype), state.get_hyper("momentum", var.dtype.base_dtype), - state.get_hyper("epsilon", var.dtype.base_dtype), + 0, grad, use_locking=self._use_locking) else: @@ -168,7 +168,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2): state.get_hyper("learning_rate", var.dtype.base_dtype), state.get_hyper("decay", var.dtype.base_dtype), state.get_hyper("momentum", var.dtype.base_dtype), - state.get_hyper("epsilon", var.dtype.base_dtype), + 0, grad, use_locking=self._use_locking) @@ -185,7 +185,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2): state.get_hyper("learning_rate", var.dtype.base_dtype), state.get_hyper("decay", var.dtype.base_dtype), state.get_hyper("momentum", var.dtype.base_dtype), - state.get_hyper("epsilon", var.dtype.base_dtype), + 0, grad.values, grad.indices, use_locking=self._use_locking) @@ -197,7 +197,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2): state.get_hyper("learning_rate", var.dtype.base_dtype), state.get_hyper("decay", var.dtype.base_dtype), state.get_hyper("momentum", var.dtype.base_dtype), - state.get_hyper("epsilon", var.dtype.base_dtype), + 0, grad.values, grad.indices, use_locking=self._use_locking) @@ -215,7 +215,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2): state.get_hyper("learning_rate", var.dtype.base_dtype), state.get_hyper("decay", var.dtype.base_dtype), state.get_hyper("momentum", var.dtype.base_dtype), - state.get_hyper("epsilon", var.dtype.base_dtype), + 0, grad, indices, use_locking=self._use_locking) @@ -227,7 +227,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2): state.get_hyper("learning_rate", var.dtype.base_dtype), state.get_hyper("decay", var.dtype.base_dtype), state.get_hyper("momentum", var.dtype.base_dtype), - state.get_hyper("epsilon", var.dtype.base_dtype), + 0, grad, indices, use_locking=self._use_locking) diff --git a/tensorflow/contrib/optimizer_v2/rmsprop_test.py b/tensorflow/contrib/optimizer_v2/rmsprop_test.py index dc23ef241a4..628d0418dd3 100644 --- a/tensorflow/contrib/optimizer_v2/rmsprop_test.py +++ b/tensorflow/contrib/optimizer_v2/rmsprop_test.py @@ -39,34 +39,34 @@ _DATA_TYPES = [dtypes.half, dtypes.float32] _TEST_PARAM_VALUES = [ # learning_rate, decay, momentum, epsilon, centered, use_resource - [0.5, 0.9, 0.0, 1e-3, True, False], - [0.5, 0.9, 0.0, 1e-3, False, False], - [0.5, 0.9, 0.0, 1e-3, True, True], - [0.5, 0.9, 0.0, 1e-3, False, True], - [0.1, 0.9, 0.0, 1e-3, True, False], - [0.5, 0.95, 0.0, 1e-3, False, False], - [0.5, 0.95, 0.0, 1e-5, True, False], - [0.5, 0.95, 0.9, 1e-5, True, False], + [0.5, 0.9, 0.0, 1.0, True, False], + [0.5, 0.9, 0.0, 1.0, False, False], + [0.5, 0.9, 0.0, 1.0, True, True], + [0.5, 0.9, 0.0, 1.0, False, True], + [0.1, 0.9, 0.0, 1.0, True, False], + [0.5, 0.95, 0.0, 1.0, False, False], + [0.5, 0.8, 0.0, 1e-3, True, False], + [0.5, 0.8, 0.9, 1e-3, True, False], ] class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, decay, momentum, - epsilon, centered): + centered): 
rms_t = rms * decay + (1 - decay) * g * g - denom_t = rms_t + epsilon if centered: mg_t = mg * decay + (1 - decay) * g - denom_t -= mg_t * mg_t + denom_t = rms_t - mg_t * mg_t else: mg_t = mg + denom_t = rms_t mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype) var_t = var - mom_t return var_t, mg_t, rms_t, mom_t def _sparse_rmsprop_update_numpy(self, var, gindexs, gvalues, mg, rms, mom, - lr, decay, momentum, epsilon, centered): + lr, decay, momentum, centered): mg_t = copy.deepcopy(mg) rms_t = copy.deepcopy(rms) mom_t = copy.deepcopy(mom) @@ -75,7 +75,7 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): gindex = gindexs[i] gvalue = gvalues[i] rms_t[gindex] = rms[gindex] * decay + (1 - decay) * gvalue * gvalue - denom_t = rms_t[gindex] + epsilon + denom_t = rms_t[gindex] if centered: mg_t[gindex] = mg_t[gindex] * decay + (1 - decay) * gvalue denom_t -= mg_t[gindex] * mg_t[gindex] @@ -129,8 +129,8 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + rms0_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype) + rms1_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype) mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) @@ -144,10 +144,10 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate, - decay, momentum, epsilon, centered) + decay, momentum, centered) var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy( var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate, - decay, momentum, epsilon, centered) + decay, momentum, centered) # Validate updated params if centered: @@ -191,7 +191,7 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): loss = pred * pred sgd_op = rmsprop.RMSPropOptimizer( learning_rate=1.0, - decay=0.0, + decay=0.1, momentum=0.0, epsilon=1.0, centered=True).minimize(loss) @@ -202,7 +202,7 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): sgd_op.run() # Validate updated params self.assertAllCloseAccordingToType( - [[-111, -138]], var0.eval(), atol=0.01) + [[-7/3.0, -4/3.0]], var0.eval(), atol=0.01) @parameterized.named_parameters( *test_util.generate_combinations_with_testcase_name( @@ -251,8 +251,8 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + rms0_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype) + rms1_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype) mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) @@ -266,10 +266,10 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy( var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np, - learning_rate, decay, momentum, epsilon, centered) + learning_rate, decay, momentum, centered) var1_np, mg1_np, rms1_np, mom1_np = 
self._sparse_rmsprop_update_numpy( var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np, - learning_rate, decay, momentum, epsilon, centered) + learning_rate, decay, momentum, centered) # Validate updated params if centered: @@ -317,13 +317,13 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): # Check the parameters. self.assertAllCloseAccordingToType( np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) + 1.0 - (0.1 * 2.0 / math.sqrt(0.901)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) ]), var0.eval()) self.assertAllCloseAccordingToType( np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) ]), var1.eval()) # Step 2: the root mean square accumulators contain the previous update. update.run() @@ -335,17 +335,17 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): # Check the parameters. self.assertAllCloseAccordingToType( np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)) + 1.0 - (0.1 * 2.0 / math.sqrt(0.901)) - + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) - + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)) ]), var0.eval()) self.assertAllCloseAccordingToType( np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)) + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)) - + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) - + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)) ]), var1.eval()) @parameterized.parameters(_DATA_TYPES) @@ -357,7 +357,7 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) opt = rmsprop.RMSPropOptimizer( - learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5) + learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1.0) update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() @@ -383,22 +383,22 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): np.array([0.90001, 0.90001]), rms1.eval()) # Check the momentum accumulators self.assertAllCloseAccordingToType( - np.array([(0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))]), mom0.eval()) + np.array([(0.1 * 2.0 / math.sqrt(0.901)), + (0.1 * 2.0 / math.sqrt(0.901))]), mom0.eval()) self.assertAllCloseAccordingToType( - np.array([(0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))]), mom1.eval()) + np.array([(0.01 * 2.0 / math.sqrt(0.90001)), + (0.01 * 2.0 / math.sqrt(0.90001))]), mom1.eval()) # Check that the parameters. 
self.assertAllCloseAccordingToType( np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + 1.0 - (0.1 * 2.0 / math.sqrt(0.901)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) ]), var0.eval()) self.assertAllCloseAccordingToType( np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) ]), var1.eval()) # Step 2: the root mean square accumulators contain the previous update. @@ -410,38 +410,38 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) self.assertAllCloseAccordingToType( np.array([ - 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)), - 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)) + 0.5 * (0.1 * 2.0 / math.sqrt(0.901)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)), + 0.5 * (0.1 * 2.0 / math.sqrt(0.901)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)) ]), mom0.eval()) self.assertAllCloseAccordingToType( np.array([ - 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)), - 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)) + 0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)), + 0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)) ]), mom1.eval()) # Check the parameters. self.assertAllCloseAccordingToType( np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - - (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - - (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))) + 1.0 - (0.1 * 2.0 / math.sqrt(0.901)) - + (0.5 * (0.1 * 2.0 / math.sqrt(0.901)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) - + (0.5 * (0.1 * 2.0 / math.sqrt(0.901)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))) ]), var0.eval()) self.assertAllCloseAccordingToType( np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - - (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - - (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))) + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)) - + (0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) - + (0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))) ]), var1.eval()) diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index cb66fd1f76b..2ddbd73ea64 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -455,6 +455,24 @@ class _LayerMatch(object): return self._bias_add_op +def _FollowedByFakeQuant(tensor): + """Returns True if the tensor is followed by a FakeQuant.""" + fake_quant_ops = set([ + 'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs', + 'FakeQuantWithMinMaxVarsPerChannel' + ]) + pass_through_ops = set(['Reshape', 'Identity']) + consumers = tensor.consumers() + 
while consumers: + c = consumers.pop() + if c.type in fake_quant_ops: + return True + elif c.type in pass_through_ops: + for output in c.outputs: + consumers.extend(output.consumers()) + return False + + + def _InsertQuantOp(context, name, producer, @@ -535,11 +553,7 @@ def _InsertQuantOp(context, # Prevent ops from being quantized multiple times. Bypass ops can sometimes # overlap between multiple matches, so we need to ensure that we don't # add duplicate FakeQuant operations. - fake_quant_ops = set([ - 'FakeQuantWithMinMaxVars', - 'FakeQuantWithMinMaxArgs' - ]) - if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])): + if _FollowedByFakeQuant(inputs): return if moving_avg: diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 06ebcdfee16..212d902a3c6 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -471,6 +471,60 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertTrue( 'part/test/test/weights_quant/FakeQuantWithMinMaxVars' in op_names) + def testSkipReshapeQuantization(self): + self._RunTestOverParameters(self._TestSkipReshapeQuantization) + + def _TestSkipReshapeQuantization(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=nn_ops.relu6, + scope='test/test') + + reshape = array_ops.reshape( + conv, (int(10), int(height / 2), int(width / 2), int(16))) + + # Insert a fake quant node after the reshape. We will check that one isn't + # inserted before it. + array_ops.fake_quant_with_min_max_vars(reshape, -1, 1) + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + + # Ensure that there isn't a FakeQuant added before the reshape. + self.assertFalse( + 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs]) + + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=nn_ops.relu6, + scope='test/test') + + reshape = array_ops.reshape( + conv, (int(10), int(height / 2), int(width / 2), int(16))) + + # If no fake quant is added after the reshape, a FakeQuant should be added + # before the reshape. + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + + # Ensure that there is a FakeQuant added before the reshape. + self.assertTrue( + 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs]) + def _WeightInit(self, stddev): """Returns truncated normal variable initializer.
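Editorial note: the new `_FollowedByFakeQuant` helper above is a small worklist search. It walks a tensor's consumers and keeps looking through pass-through ops (`Reshape`, `Identity`) until it either reaches a FakeQuant op or runs out of consumers, which is why `testSkipReshapeQuantization` expects no second FakeQuant before the reshape when one already exists after it. Below is a self-contained sketch of the same traversal over a toy op graph; the `Op` class and the wiring are invented for illustration and no TensorFlow is required.

```python
# Toy model of the consumer walk in _FollowedByFakeQuant. Each Op has a type
# and a list of consumer ops reading its (single) output.
class Op(object):
  def __init__(self, op_type):
    self.type = op_type
    self.consumers = []

  def feed(self, other):
    self.consumers.append(other)
    return other


FAKE_QUANT_OPS = {'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
                  'FakeQuantWithMinMaxVarsPerChannel'}
PASS_THROUGH_OPS = {'Reshape', 'Identity'}


def followed_by_fake_quant(op):
  """True if a FakeQuant is reachable through pass-through ops only."""
  worklist = list(op.consumers)
  while worklist:
    c = worklist.pop()
    if c.type in FAKE_QUANT_OPS:
      return True
    if c.type in PASS_THROUGH_OPS:
      worklist.extend(c.consumers)   # keep searching through Reshape/Identity
  return False


conv = Op('Conv2D')
reshape = conv.feed(Op('Reshape'))
reshape.feed(Op('FakeQuantWithMinMaxVars'))
assert followed_by_fake_quant(conv)              # found via the Reshape
assert not followed_by_fake_quant(Op('Conv2D'))  # nothing downstream
```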
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 2a84629080d..5874245d58e 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -149,7 +149,7 @@ cuda_py_tests( cuda_py_tests( name = "core_rnn_test", - size = "large", + size = "medium", srcs = ["python/kernel_tests/core_rnn_test.py"], additional_deps = [ ":rnn_py", @@ -175,7 +175,7 @@ cuda_py_tests( tf_py_test( name = "fused_rnn_cell_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/fused_rnn_cell_test.py"], additional_deps = [ ":rnn_py", @@ -192,10 +192,6 @@ tf_py_test( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], - tags = [ - "manual", - "notap", - ], ) cuda_py_tests( diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index 1c20d88fe4b..d62ec45d186 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -1288,7 +1288,10 @@ class LSTMTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testDynamicEquivalentToStaticRNN(self): self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) - self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) + + @test_util.run_in_graph_and_eager_modes + def testDynamicEquivalentToStaticRNNWithSequenceLength(self): + self._testDynamicEquivalentToStaticRNN(use_sequence_length=True) class BidirectionalRNNTest(test.TestCase): diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index fbb50befdfb..e7eb4ac5635 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -113,7 +113,6 @@ py_test( size = "small", srcs = ["python/saved_model/keras_saved_model_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], deps = [ ":saved_model_py", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD index 0b8fc0cdc66..412a2c81a14 100644 --- a/tensorflow/contrib/stat_summarizer/BUILD +++ b/tensorflow/contrib/stat_summarizer/BUILD @@ -31,8 +31,5 @@ tf_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:variables", ], - tags = [ - "no_windows", - "notap", # TODO(b/80546574): test is flaky - ], + tags = ["notap"], # TODO(b/80546574): test is flaky ) diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index 35e8c92aba3..8fa0b3ada94 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -22,10 +22,12 @@ from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib - from tensorflow.contrib.tensor_forest.client import eval_metrics from tensorflow.contrib.tensor_forest.python import tensor_forest - +from tensorflow.python.estimator import estimator as core_estimator +from tensorflow.python.estimator.canned import head as core_head_lib +from tensorflow.python.estimator.export.export_output import PredictOutput +from tensorflow.python.feature_column import feature_column as fc_core from tensorflow.python.framework import ops from tensorflow.python.framework import 
sparse_tensor from tensorflow.python.ops import array_ops @@ -34,12 +36,12 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util - KEYS_NAME = 'keys' LOSS_NAME = 'rf_training_loss' TREE_PATHS_PREDICTION_KEY = 'tree_paths' @@ -48,6 +50,11 @@ ALL_SERVING_KEY = 'tensorforest_all' EPSILON = 0.000001 +class ModelBuilderOutputType(object): + MODEL_FN_OPS = 0 + ESTIMATOR_SPEC = 1 + + class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook): def __init__(self, op_dict): @@ -106,20 +113,34 @@ class TensorForestLossHook(session_run_hook.SessionRunHook): run_context.request_stop() -def get_default_head(params, weights_name, name=None): - if params.regression: - return head_lib.regression_head( - weight_column_name=weights_name, - label_dimension=params.num_outputs, - enable_centered_bias=False, - head_name=name) +def _get_default_head(params, weights_name, output_type, name=None): + """Creates a default head based on a type of a problem.""" + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + if params.regression: + return head_lib.regression_head( + weight_column_name=weights_name, + label_dimension=params.num_outputs, + enable_centered_bias=False, + head_name=name) + else: + return head_lib.multi_class_head( + params.num_classes, + weight_column_name=weights_name, + enable_centered_bias=False, + head_name=name) else: - return head_lib.multi_class_head( - params.num_classes, - weight_column_name=weights_name, - enable_centered_bias=False, - head_name=name) - + if params.regression: + return core_head_lib._regression_head( # pylint:disable=protected-access + weight_column=weights_name, + label_dimension=params.num_outputs, + name=name, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + else: + return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access + n_classes=params.num_classes, + weight_column=weights_name, + name=name, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) def get_model_fn(params, graph_builder_class, @@ -135,19 +156,27 @@ def get_model_fn(params, report_feature_importances=False, local_eval=False, head_scope=None, - include_all_in_serving=False): + include_all_in_serving=False, + output_type=ModelBuilderOutputType.MODEL_FN_OPS): """Return a model function given a way to construct a graph builder.""" if model_head is None: - model_head = get_default_head(params, weights_name) + model_head = _get_default_head(params, weights_name, output_type) def _model_fn(features, labels, mode): """Function that returns predictions, training loss, and training op.""" + if (isinstance(features, ops.Tensor) or isinstance(features, sparse_tensor.SparseTensor)): features = {'features': features} if feature_columns: features = features.copy() - features.update(layers.transform_features(features, feature_columns)) + + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + features.update(layers.transform_features(features, feature_columns)) + else: + for fc in feature_columns: + tensor = fc_core._transform_features(features, [fc])[fc] # pylint: disable=protected-access + features[fc.name] = tensor weights = None if 
weights_name and weights_name in features: @@ -201,52 +230,95 @@ def get_model_fn(params, def _train_fn(unused_loss): return training_graph - model_ops = model_head.create_model_fn_ops( - features=features, - labels=labels, - mode=mode, - train_op_fn=_train_fn, - logits=logits, - scope=head_scope) # Ops are run in lexigraphical order of their keys. Run the resource # clean-up op last. all_handles = graph_builder.get_all_resource_handles() ops_at_end = { - '9: clean up resources': control_flow_ops.group( - *[resource_variable_ops.destroy_resource_op(handle) - for handle in all_handles])} + '9: clean up resources': + control_flow_ops.group(*[ + resource_variable_ops.destroy_resource_op(handle) + for handle in all_handles + ]) + } if report_feature_importances: ops_at_end['1: feature_importances'] = ( graph_builder.feature_importances()) - training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end)) + training_hooks = [TensorForestRunOpAtEndHook(ops_at_end)] - if early_stopping_rounds: - training_hooks.append( - TensorForestLossHook( - early_stopping_rounds, - early_stopping_loss_threshold=early_stopping_loss_threshold, - loss_op=model_ops.loss)) + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + model_ops = model_head.create_model_fn_ops( + features=features, + labels=labels, + mode=mode, + train_op_fn=_train_fn, + logits=logits, + scope=head_scope) - model_ops.training_hooks.extend(training_hooks) + if early_stopping_rounds: + training_hooks.append( + TensorForestLossHook( + early_stopping_rounds, + early_stopping_loss_threshold=early_stopping_loss_threshold, + loss_op=model_ops.loss)) - if keys is not None: - model_ops.predictions[keys_name] = keys + model_ops.training_hooks.extend(training_hooks) - if params.inference_tree_paths: - model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths + if keys is not None: + model_ops.predictions[keys_name] = keys - model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance - if include_all_in_serving: - # In order to serve the variance we need to add the prediction dict - # to output_alternatives dict. - if not model_ops.output_alternatives: - model_ops.output_alternatives = {} - model_ops.output_alternatives[ALL_SERVING_KEY] = ( - constants.ProblemType.UNSPECIFIED, model_ops.predictions) - return model_ops + if params.inference_tree_paths: + model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths + + model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance + + if include_all_in_serving: + # In order to serve the variance we need to add the prediction dict + # to output_alternatives dict. 
+ if not model_ops.output_alternatives: + model_ops.output_alternatives = {} + model_ops.output_alternatives[ALL_SERVING_KEY] = ( + constants.ProblemType.UNSPECIFIED, model_ops.predictions) + + return model_ops + + else: + # Estimator spec + estimator_spec = model_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_fn, + logits=logits) + + if early_stopping_rounds: + training_hooks.append( + TensorForestLossHook( + early_stopping_rounds, + early_stopping_loss_threshold=early_stopping_loss_threshold, + loss_op=estimator_spec.loss)) + + estimator_spec = estimator_spec._replace( + training_hooks=training_hooks + list(estimator_spec.training_hooks)) + if keys is not None: + estimator_spec.predictions[keys_name] = keys + if params.inference_tree_paths: + estimator_spec.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths + estimator_spec.predictions[VARIANCE_PREDICTION_KEY] = regression_variance + + if include_all_in_serving: + outputs = estimator_spec.export_outputs + if not outputs: + outputs = {} + outputs = {ALL_SERVING_KEY: PredictOutput(estimator_spec.predictions)} + print(estimator_spec.export_outputs) + # In order to serve the variance we need to add the prediction dict + # to output_alternatives dict. + estimator_spec = estimator_spec._replace(export_outputs=outputs) + + return estimator_spec return _model_fn @@ -493,8 +565,11 @@ class MultiForestMultiHeadEstimator(estimator.Estimator): params, graph_builder_class, device_assigner, - model_head=get_default_head( - params, weight_column, name='head{0}'.format(i)), + model_head=_get_default_head( + params, + weight_column, + name='head{0}'.format(i), + output_type=ModelBuilderOutputType.MODEL_FN_OPS), weights_name=weight_column, keys_name=keys_column, early_stopping_rounds=early_stopping_rounds, @@ -509,3 +584,142 @@ class MultiForestMultiHeadEstimator(estimator.Estimator): model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) + + +class CoreTensorForestEstimator(core_estimator.Estimator): + """A CORE estimator that can train and evaluate a random forest. + + Example: + + ```python + params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams( + num_classes=2, num_features=40, num_trees=10, max_nodes=1000) + + # Estimator using the default graph builder. + estimator = CoreTensorForestEstimator(params, model_dir=model_dir) + + # Or estimator using TrainingLossForest as the graph builder. + estimator = CoreTensorForestEstimator( + params, graph_builder_class=tensor_forest.TrainingLossForest, + model_dir=model_dir) + + # Input builders + def input_fn_train: # returns x, y + ... + def input_fn_eval: # returns x, y + ... + estimator.train(input_fn=input_fn_train) + estimator.evaluate(input_fn=input_fn_eval) + + # Predict returns an iterable of dicts. + results = list(estimator.predict(x=x)) + prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME] + prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME] + ``` + """ + + def __init__(self, + params, + device_assigner=None, + model_dir=None, + feature_columns=None, + graph_builder_class=tensor_forest.RandomForestGraphs, + config=None, + weight_column=None, + keys_column=None, + feature_engineering_fn=None, + early_stopping_rounds=100, + early_stopping_loss_threshold=0.001, + num_trainers=1, + trainer_id=0, + report_feature_importances=False, + local_eval=False, + version=None, + head=None, + include_all_in_serving=False): + """Initializes a TensorForestEstimator instance. 
+ + Args: + params: ForestHParams object that holds random forest hyperparameters. + These parameters will be passed into `model_fn`. + device_assigner: An `object` instance that controls how trees get + assigned to devices. If `None`, will use + `tensor_forest.RandomForestDeviceAssigner`. + model_dir: Directory to save model parameters, graph, etc. To continue + training a previously saved model, load checkpoints saved to this + directory into an estimator. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `_FeatureColumn`. + graph_builder_class: An `object` instance that defines how TF graphs for + random forest training and inference are built. By default will use + `tensor_forest.RandomForestGraphs`. Can be overridden by version + kwarg. + config: `RunConfig` object to configure the runtime settings. + weight_column: A string defining feature column name representing + weights. Will be multiplied by the loss of the example. Used to + downweight or boost examples during training. + keys_column: A string naming one of the features to strip out and + pass through into the inference/eval results dict. Useful for + associating specific examples with their prediction. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + early_stopping_rounds: Allows training to terminate early if the forest is + no longer growing. 100 by default. Set to a Falsy value to disable + the default training hook. + early_stopping_loss_threshold: Percentage (as fraction) that loss must + improve by within early_stopping_rounds steps, otherwise training will + terminate. + num_trainers: Number of training jobs, which will partition trees + among them. + trainer_id: Which trainer this instance is. + report_feature_importances: If True, print out feature importances + during evaluation. + local_eval: If True, don't use a device assigner for eval. This is to + support some common setups where eval is done on a single machine, even + though training might be distributed. + version: Unused. + head: A heads_lib.Head object that calculates losses and such. If None, + one will be automatically created based on params. + include_all_in_serving: if True, allow preparation of the complete + prediction dict including the variance to be exported for serving with + the Servo lib; and it also requires calling export_savedmodel with + default_output_alternative_key=ALL_SERVING_KEY, i.e. + estimator.export_savedmodel(export_dir_base=your_export_dir, + serving_input_fn=your_export_input_fn, + default_output_alternative_key=ALL_SERVING_KEY) + if False, resort to default behavior, i.e. export scores and + probabilities but no variances. In this case + default_output_alternative_key should be None while calling + export_savedmodel(). + Note, that due to backward compatibility we cannot always set + include_all_in_serving to True because in this case calling + export_saved_model() without + default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the + saved_model_export_utils.get_output_alternatives() would raise + ValueError. + + Returns: + A `TensorForestEstimator` instance. 
+ """ + + super(CoreTensorForestEstimator, self).__init__( + model_fn=get_model_fn( + params.fill(), + graph_builder_class, + device_assigner, + feature_columns=feature_columns, + model_head=head, + weights_name=weight_column, + keys_name=keys_column, + early_stopping_rounds=early_stopping_rounds, + early_stopping_loss_threshold=early_stopping_loss_threshold, + num_trainers=num_trainers, + trainer_id=trainer_id, + report_feature_importances=report_feature_importances, + local_eval=local_eval, + include_all_in_serving=include_all_in_serving, + output_type=ModelBuilderOutputType.ESTIMATOR_SPEC), + model_dir=model_dir, + config=config) diff --git a/tensorflow/contrib/tensor_forest/client/random_forest_test.py b/tensorflow/contrib/tensor_forest/client/random_forest_test.py index ac42364d257..aa0016b7408 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest_test.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest_test.py @@ -23,7 +23,39 @@ import numpy as np from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.tensor_forest.client import random_forest from tensorflow.contrib.tensor_forest.python import tensor_forest +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column_lib as core_feature_column +from tensorflow.python.framework import ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_utils + + +def _get_classification_input_fns(): + iris = base.load_iris() + data = iris.data.astype(np.float32) + labels = iris.target.astype(np.int32) + + train_input_fn = numpy_io.numpy_input_fn( + x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False) + + predict_input_fn = numpy_io.numpy_input_fn( + x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False) + return train_input_fn, predict_input_fn + + +def _get_regression_input_fns(): + boston = base.load_boston() + data = boston.data.astype(np.float32) + labels = boston.target.astype(np.int32) + + train_input_fn = numpy_io.numpy_input_fn( + x=data, y=labels, batch_size=506, num_epochs=None, shuffle=False) + + predict_input_fn = numpy_io.numpy_input_fn( + x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False) + return train_input_fn, predict_input_fn class TensorForestTrainerTests(test.TestCase): @@ -39,18 +71,22 @@ class TensorForestTrainerTests(test.TestCase): inference_tree_paths=True) classifier = random_forest.TensorForestEstimator(hparams.fill()) - iris = base.load_iris() - data = iris.data.astype(np.float32) - labels = iris.target.astype(np.int32) + input_fn, predict_input_fn = _get_classification_input_fns() + classifier.fit(input_fn=input_fn, steps=100) + res = classifier.evaluate(input_fn=input_fn, steps=10) - classifier.fit(x=data, y=labels, steps=100, batch_size=50) - classifier.evaluate(x=data, y=labels, steps=10) + self.assertEqual(1.0, res['accuracy']) + self.assertAllClose(0.55144483, res['loss']) + + predictions = list(classifier.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0.576117, 0.211942, 0.211942]], + [pred['probabilities'] for pred in predictions]) def testRegression(self): - """Tests multi-class classification using matrix data as input.""" + """Tests regression using matrix data as input.""" hparams = tensor_forest.ForestHParams( - num_trees=3, + num_trees=5, max_nodes=1000, num_classes=1, num_features=13, @@ 
-59,12 +95,263 @@ class TensorForestTrainerTests(test.TestCase): regressor = random_forest.TensorForestEstimator(hparams.fill()) - boston = base.load_boston() - data = boston.data.astype(np.float32) - labels = boston.target.astype(np.int32) + input_fn, predict_input_fn = _get_regression_input_fns() - regressor.fit(x=data, y=labels, steps=100, batch_size=50) - regressor.evaluate(x=data, y=labels, steps=10) + regressor.fit(input_fn=input_fn, steps=100) + res = regressor.evaluate(input_fn=input_fn, steps=10) + self.assertGreaterEqual(0.1, res['loss']) + + predictions = list(regressor.predict(input_fn=predict_input_fn)) + self.assertAllClose([24.], [pred['scores'] for pred in predictions], atol=1) + + def testAdditionalOutputs(self): + """Tests multi-class classification using matrix data as input.""" + hparams = tensor_forest.ForestHParams( + num_trees=1, + max_nodes=100, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + classifier = random_forest.TensorForestEstimator( + hparams.fill(), keys_column='keys', include_all_in_serving=True) + + iris = base.load_iris() + data = iris.data.astype(np.float32) + labels = iris.target.astype(np.int32) + + input_fn = numpy_io.numpy_input_fn( + x={ + 'x': data, + 'keys': np.arange(len(iris.data)).reshape(150, 1) + }, + y=labels, + batch_size=10, + num_epochs=1, + shuffle=False) + + classifier.fit(input_fn=input_fn, steps=100) + predictions = list(classifier.predict(input_fn=input_fn)) + # Check that there is a key column, tree paths and var. + for pred in predictions: + self.assertTrue('keys' in pred) + self.assertTrue('tree_paths' in pred) + self.assertTrue('prediction_variance' in pred) + + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertLessEqual( + reader.get_tensor(ops.GraphKeys.GLOBAL_STEP), global_step) + + def testEarlyStopping(self): + """Tests multi-class classification using matrix data as input.""" + hparams = tensor_forest.ForestHParams( + num_trees=100, + max_nodes=10000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + classifier = random_forest.TensorForestEstimator( + hparams.fill(), + # Set a crazy threshold - 30% loss change. + early_stopping_loss_threshold=0.3, + early_stopping_rounds=2) + + input_fn, _ = _get_classification_input_fns() + classifier.fit(input_fn=input_fn, steps=100) + + # We stopped early. 
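Editorial note: `testEarlyStopping` above (together with the checkpoint assertion that follows) exercises the `early_stopping_rounds=2` / `early_stopping_loss_threshold=0.3` contract that `get_model_fn` wires into `TensorForestLossHook`. The snippet below is only a plain-Python sketch of that stopping rule as the arguments describe it, with made-up loss sequences; it is not the hook's actual source.

```python
# Sketch of the loss-based early-stopping rule: stop once the loss has not
# improved by the given fraction within the given number of rounds.
def should_stop(losses, early_stopping_rounds=2,
                early_stopping_loss_threshold=0.3):
  if len(losses) <= early_stopping_rounds:
    return False
  recent_best = min(losses[-early_stopping_rounds:])
  earlier_best = min(losses[:-early_stopping_rounds])
  # Stop if the best recent loss failed to beat the earlier best by at least
  # the required fraction.
  return recent_best > earlier_best * (1.0 - early_stopping_loss_threshold)


# With a 30% threshold, a loss that only creeps downward triggers a stop.
print(should_stop([1.0, 0.99, 0.985]))  # True: under 30% improvement
print(should_stop([1.0, 0.6, 0.35]))    # False: still improving fast enough
```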
+ self._assert_checkpoint(classifier.model_dir, global_step=5) + + +class CoreTensorForestTests(test.TestCase): + + def testTrainEvaluateInferDoesNotThrowErrorForClassifier(self): + head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + hparams = tensor_forest.ForestHParams( + num_trees=3, + max_nodes=1000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + + est = random_forest.CoreTensorForestEstimator(hparams.fill(), head=head_fn) + + input_fn, predict_input_fn = _get_classification_input_fns() + + est.train(input_fn=input_fn, steps=100) + res = est.evaluate(input_fn=input_fn, steps=1) + + self.assertEqual(1.0, res['accuracy']) + self.assertAllClose(0.55144483, res['loss']) + + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0.576117, 0.211942, 0.211942]], + [pred['probabilities'] for pred in predictions]) + + def testRegression(self): + """Tests regression using matrix data as input.""" + head_fn = head_lib._regression_head( + label_dimension=1, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + hparams = tensor_forest.ForestHParams( + num_trees=5, + max_nodes=1000, + num_classes=1, + num_features=13, + regression=True, + split_after_samples=20) + + regressor = random_forest.CoreTensorForestEstimator( + hparams.fill(), head=head_fn) + + input_fn, predict_input_fn = _get_regression_input_fns() + + regressor.train(input_fn=input_fn, steps=100) + res = regressor.evaluate(input_fn=input_fn, steps=10) + self.assertGreaterEqual(0.1, res['loss']) + + predictions = list(regressor.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[24.]], [pred['predictions'] for pred in predictions], atol=1) + + def testWithFeatureColumns(self): + head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + hparams = tensor_forest.ForestHParams( + num_trees=3, + max_nodes=1000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + + est = random_forest.CoreTensorForestEstimator( + hparams.fill(), + head=head_fn, + feature_columns=[core_feature_column.numeric_column('x')]) + + iris = base.load_iris() + data = {'x': iris.data.astype(np.float32)} + labels = iris.target.astype(np.int32) + + input_fn = numpy_io.numpy_input_fn( + x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False) + + est.train(input_fn=input_fn, steps=100) + res = est.evaluate(input_fn=input_fn, steps=1) + + self.assertEqual(1.0, res['accuracy']) + self.assertAllClose(0.55144483, res['loss']) + + def testAutofillsClassificationHead(self): + hparams = tensor_forest.ForestHParams( + num_trees=3, + max_nodes=1000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + + est = random_forest.CoreTensorForestEstimator(hparams.fill()) + + input_fn, _ = _get_classification_input_fns() + + est.train(input_fn=input_fn, steps=100) + res = est.evaluate(input_fn=input_fn, steps=1) + + self.assertEqual(1.0, res['accuracy']) + self.assertAllClose(0.55144483, res['loss']) + + def testAutofillsRegressionHead(self): + hparams = tensor_forest.ForestHParams( + num_trees=5, + max_nodes=1000, + num_classes=1, + num_features=13, + regression=True, + split_after_samples=20) + + regressor = random_forest.CoreTensorForestEstimator(hparams.fill()) + + input_fn, predict_input_fn = _get_regression_input_fns() + + 
regressor.train(input_fn=input_fn, steps=100) + res = regressor.evaluate(input_fn=input_fn, steps=10) + self.assertGreaterEqual(0.1, res['loss']) + + predictions = list(regressor.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[24.]], [pred['predictions'] for pred in predictions], atol=1) + + def testAdditionalOutputs(self): + """Tests multi-class classification using matrix data as input.""" + hparams = tensor_forest.ForestHParams( + num_trees=1, + max_nodes=100, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + classifier = random_forest.CoreTensorForestEstimator( + hparams.fill(), keys_column='keys', include_all_in_serving=True) + + iris = base.load_iris() + data = iris.data.astype(np.float32) + labels = iris.target.astype(np.int32) + + input_fn = numpy_io.numpy_input_fn( + x={ + 'x': data, + 'keys': np.arange(len(iris.data)).reshape(150, 1) + }, + y=labels, + batch_size=10, + num_epochs=1, + shuffle=False) + + classifier.train(input_fn=input_fn, steps=100) + predictions = list(classifier.predict(input_fn=input_fn)) + # Check that there is a key column, tree paths and var. + for pred in predictions: + self.assertTrue('keys' in pred) + self.assertTrue('tree_paths' in pred) + self.assertTrue('prediction_variance' in pred) + + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertLessEqual( + reader.get_tensor(ops.GraphKeys.GLOBAL_STEP), global_step) + + def testEarlyStopping(self): + head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + hparams = tensor_forest.ForestHParams( + num_trees=3, + max_nodes=1000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + + est = random_forest.CoreTensorForestEstimator( + hparams.fill(), + head=head_fn, + # Set a crazy threshold - 30% loss change. + early_stopping_loss_threshold=0.3, + early_stopping_rounds=2) + + input_fn, _ = _get_classification_input_fns() + est.train(input_fn=input_fn, steps=100) + # We stopped early. + self._assert_checkpoint(est.model_dir, global_step=8) if __name__ == "__main__": diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 2abf402e6cf..56e451e2e37 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -265,7 +265,6 @@ tf_py_test( ":datasets", ], grpc_enabled = True, - tags = ["no_windows"], ) tf_py_test( diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index 19f088f8b86..d4ccb0f2467 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -20,7 +20,7 @@ from __future__ import print_function from setuptools import setup -_VERSION = '1.9.0' +_VERSION = '1.10.0' CONSOLE_SCRIPTS = [ 'capture_tpu_profile=cloud_tpu_profiler.main:run_main', diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h index 1bf49966d12..aee094177bf 100644 --- a/tensorflow/contrib/tpu/profiler/version.h +++ b/tensorflow/contrib/tpu/profiler/version.h @@ -16,6 +16,6 @@ limitations under the License. 
#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ #define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ -#define TPU_PROFILER_VERSION "1.9.0" +#define TPU_PROFILER_VERSION "1.10.0" #endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 029492b489e..f2211555681 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -45,6 +45,7 @@ from tensorflow.core.framework import variable_pb2 from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest as data_nest from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util as estimator_util @@ -204,6 +205,12 @@ def _increase_eval_step_op(iterations_per_loop): use_locking=True) +def _extract_key_names(tensor_or_dict): + if isinstance(tensor_or_dict, dict): + return sorted(tensor_or_dict.keys()) + return [] + + class _SIGNAL(object): """Signal used to control the thread of infeed/outfeed. @@ -224,7 +231,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote `metric_fn` runs on CPU to generate metrics and `tensors` represents the `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`. To be precise, TPU evaluation expects a slightly different signature from the - `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a + @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`. The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The `tensors` usually specify the model logits, which are transferred back from @@ -247,7 +254,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote sending tensors from TPU to CPU. To reduce the overhead, try reducing the size of the tensors. The `tensors` are concatenated along their major (batch) dimension, and so must be >= rank 1. The `host_call` is useful for writing - summaries with `tf.contrib.summary.create_file_writer`. + summaries with @{tf.contrib.summary.create_file_writer}. """ def __new__(cls, @@ -711,8 +718,7 @@ def generate_per_host_enqueue_ops_fn_for_host( features, labels = inputs.features_and_labels() signals = inputs.signals() - inputs_structure_recorder.validate_and_record_structure( - features, labels, signals) + inputs_structure_recorder.validate_and_record_structure(features, labels) unsharded_tensor_list = ( inputs_structure_recorder.flatten_features_and_labels( features, labels, signals)) @@ -859,7 +865,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, signals = inputs.signals() inputs_structure_recorder.validate_and_record_structure( - features, labels, signals) + features, labels) flattened_inputs = ( inputs_structure_recorder.flatten_features_and_labels( features, labels, signals)) @@ -901,17 +907,19 @@ class _InputPipeline(object): inputs returned by the `input_fn` can have one of the following forms: 1. features 2. (features, labels) + 3. ((arbitrarily nested structure of features), labels) Internally, form 1 is reformed to `(features, None)` as features and labels are passed separately to underlying methods. 
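Editorial note: the nested input form added above works because the recorder further down now keeps everything in a single `_feature_structure` dict and round-trips it with `data_nest.flatten` / `data_nest.pack_sequence_as`, which makes the removed per-key bookkeeping (`_feature_names`, `_label_names`, `_SignalsHelper`) unnecessary. The snippet below is a minimal sketch of that round trip; it assumes a TensorFlow 1.10-era source tree where this private module is importable, and the structure itself is a made-up example.

```python
# Sketch of the flatten / pack_sequence_as round trip used for TPU infeed.
from tensorflow.python.data.util import nest as data_nest

structure = {
    'features': {'pixels': 0.5, 'meta': (1.0, 2.0)},  # arbitrarily nested
    'labels': 3.0,
}
flat = data_nest.flatten(structure)        # deterministic flat list for infeed
restored = data_nest.pack_sequence_as(structure, flat)

assert restored['features']['meta'] == (1.0, 2.0)
assert restored['labels'] == 3.0
```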
For TPU training, TPUEstimator may expect multiple `features` and `labels` tuples one for each core. TPUEstimator allows various different structures for inputs (namely `features` - and `labels`). `features` can be `Tensor` or dict of string name to `Tensor`, - and `labels` could be `None`, `Tensor`, or dict of string name to `Tensor`. - TPU infeed/outfeed library expects flattened tensor list. So, `features` and - `labels` need to be flattened, before infeed enqueue, and the structure of - them needs to be recorded, in order to restore them after infeed dequeue. + and `labels`). `features` can be `Tensor`, dict of string name to `Tensor`, + or nested tuples and `labels` could be `None`, `Tensor`, or dict of string + name to `Tensor`. TPU infeed/outfeed library expects flattened tensor list. + So, `features` and `labels` need to be flattened, before infeed enqueue, and + the structure of them needs to be recorded, in order to restore them after + infeed dequeue. """ class InputsStructureRecorder(object): @@ -919,10 +927,7 @@ class _InputPipeline(object): def __init__(self, input_partition_dims=None): # Holds the structure of inputs - self._feature_names = [] - self._label_names = [] - self._has_labels = False - self._signals_helper = None + self._feature_structure = {} self._flattened_input_dims = None if input_partition_dims: @@ -949,7 +954,7 @@ class _InputPipeline(object): return self._flattened_input_dims def has_labels(self): - return self._has_labels + return 'labels' in self._feature_structure def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims, label_dims_names, label_names, has_labels): @@ -977,35 +982,16 @@ class _InputPipeline(object): return flattened_input_dims - def validate_and_record_structure(self, features, labels, signals=None): + def validate_and_record_structure(self, features, labels): """Validates and records the structure of `features` and `labels`.""" - - def _extract_key_names(tensor_or_dict): - if tensor_or_dict is None: - return [] - return sorted(tensor_or_dict.keys()) if isinstance( - tensor_or_dict, dict) else [] - # Extract structure. has_labels = labels is not None feature_names = _extract_key_names(features) label_names = _extract_key_names(labels) - if signals is not None and self._signals_helper is None: - # Record signals helper. - self._signals_helper = _SignalsHelper(signals) - - if self._initialized: - # Verify the structure is same. The following should never happen. - assert feature_names == self._feature_names, 'feature keys mismatched' - assert label_names == self._label_names, 'label keys mismatched' - assert has_labels == self._has_labels, 'label presence mismatched' - else: + if not self._initialized: # Record structure. self._initialized = True - self._feature_names = feature_names - self._label_names = label_names - self._has_labels = has_labels if self._feature_dims is not None: feature_dims_names = _extract_key_names(self._feature_dims) if feature_dims_names != feature_names: @@ -1027,24 +1013,12 @@ class _InputPipeline(object): def flatten_features_and_labels(self, features, labels, signals=None): """Flattens the `features` and `labels` to a single tensor list.""" - flattened_inputs = [] - if self._feature_names: - # We need a fixed ordering for enqueueing and dequeueing. 
- flattened_inputs.extend( - [features[name] for name in self._feature_names]) - else: - flattened_inputs.append(features) - + self._feature_structure['features'] = features if labels is not None: - if self._label_names: - # We need a fixed ordering for enqueueing and dequeueing. - flattened_inputs.extend([labels[name] for name in self._label_names]) - else: - flattened_inputs.append(labels) - + self._feature_structure['labels'] = labels if signals is not None: - flattened_inputs.extend(_SignalsHelper.as_tensor_list(signals)) - return flattened_inputs + self._feature_structure['signals'] = signals + return data_nest.flatten(self._feature_structure) def unflatten_features_and_labels(self, flattened_inputs): """Restores the flattened inputs to original features and labels form. @@ -1061,49 +1035,13 @@ class _InputPipeline(object): ValueError: If the number of expected tensors from `flattened_inputs` mismatches the recorded structure. """ - expected_num_features = ( - len(self._feature_names) if self._feature_names else 1) - if self._has_labels: - expected_num_labels = ( - len(self._label_names) if self._label_names else 1) - else: - expected_num_labels = 0 - expected_num_signals = ( - self._signals_helper.num_signals if self._signals_helper else 0) - - expected_num_tensors = ( - expected_num_features + expected_num_labels + expected_num_signals) - - if expected_num_tensors != len(flattened_inputs): - raise ValueError( - 'The number of flattened tensors mismatches expected num. ' - 'Expected {}, got {}'.format(expected_num_tensors, - len(flattened_inputs))) - if self._feature_names: - unflattened_features = dict( - zip(self._feature_names, flattened_inputs[:expected_num_features])) - else: - # Single tensor case - unflattened_features = flattened_inputs[0] - - if expected_num_labels == 0: - unflattened_label = None - elif self._label_names: - label_list = flattened_inputs[ - expected_num_features:expected_num_features + expected_num_labels] - unflattened_label = dict(zip(self._label_names, label_list)) - else: - # Single tensor case. - unflattened_label = flattened_inputs[expected_num_features] - - signals = None - if expected_num_signals != 0: - tensor_list_for_signals = flattened_inputs[ - expected_num_features + expected_num_labels:] - signals = self._signals_helper.unflatten(tensor_list_for_signals) - - return _Inputs(unflattened_features, unflattened_label, signals=signals) + unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure, + flattened_inputs) + return _Inputs( + unflattened_inputs['features'], + unflattened_inputs.get('labels'), + signals=unflattened_inputs.get('signals')) def __init__(self, input_fn, batch_axis, ctx): """Constructor. @@ -1505,12 +1443,14 @@ class _ModelFnWrapper(object): 'The {} to the model returned by input_fn must have static shape.' ' Tensor: {}'.format(obj_name, obj)) else: - for (key, tensor) in obj.items(): - if not tensor.get_shape().is_fully_defined(): - raise ValueError( - 'The {} to the model returned by input_fn must have static ' - 'shape. Key: \'{}\', Tensor: {}'.format( - obj_name, key, tensor)) + for (key, value) in obj.items(): + flattened_tensors = data_nest.flatten(value) + for tensor in flattened_tensors: + if not tensor.get_shape().is_fully_defined(): + raise ValueError( + 'The {} to the model returned by input_fn must have static ' + 'shape. 
Key: \'{}\', Tensor: {}'.format( + obj_name, key, tensor)) validate(features, 'features') if labels is not None: @@ -3338,26 +3278,6 @@ class _PaddingSignals(object): return padding_mask -class _SignalsHelper(object): - """A general helper class to handle common signals manipulation.""" - - def __init__(self, signals): - self._signal_keys = [] - for key in sorted(iter(signals.keys())): - self._signal_keys.append(key) - - @property - def num_signals(self): - return len(self._signal_keys) - - def unflatten(self, tensor_list): - return dict(zip(self._signal_keys, tensor_list)) - - @staticmethod - def as_tensor_list(signals): - return [signals[key] for key in sorted(iter(signals.keys()))] - - def _verify_cross_hosts_transfer_size(tensor_dict, message): total_size = 0 tensor_structure = {} diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index f72e0a3f831..c272a2ac144 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ b/tensorflow/contrib/training/python/training/training.py @@ -484,7 +484,8 @@ def train(train_op, save_checkpoint_secs=600, save_summaries_steps=100, config=None, - max_wait_secs=7200): + max_wait_secs=7200, + run_metadata=None): """Runs the training loop. Args: @@ -511,6 +512,7 @@ def train(train_op, become available. This should be kept relatively short to help detect incorrect code, but sometimes may need to be increased if the chief takes a while to start up. + run_metadata: A [`RunMetadata`] protocol buffer. Returns: the value of the loss function after training. @@ -541,5 +543,5 @@ def train(train_op, max_wait_secs=max_wait_secs) as session: loss = None while not session.should_stop(): - loss = session.run(train_op) + loss = session.run(train_op, run_metadata=run_metadata) return loss diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 82443fd7e88..9a8c20b1fde 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -149,6 +149,7 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library") load( "//third_party/mkl:build_defs.bzl", "if_mkl", + "mkl_deps", ) exports_files(["ops/ops.pbtxt"]) @@ -735,7 +736,10 @@ cc_library( "util/reporter.h", ], copts = tf_copts(), - linkopts = ["-lm"], + linkopts = select({ + "//tensorflow:windows": [], + "//conditions:default": ["-lm"], + }), visibility = ["//visibility:public"], deps = [ ":lib", @@ -860,7 +864,6 @@ tf_cuda_library( "util/work_sharder.h", ] + select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "util/memmapped_file_system.h", "util/memmapped_file_system_writer.h", @@ -2036,7 +2039,6 @@ cc_library( linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//tensorflow:android": [], "//conditions:default": [ "-ldl", @@ -2126,7 +2128,6 @@ cc_library( linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": ["-ldl"], }), deps = [ @@ -2151,7 +2152,6 @@ cc_library( linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": ["-ldl"], }), deps = [ @@ -2183,7 +2183,6 @@ cc_library( linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": ["-ldl"], }), deps = [ @@ -2489,7 +2488,6 @@ tf_cuda_library( ], ) + select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": 
[], "//conditions:default": [ "util/memmapped_file_system.cc", "util/memmapped_file_system_writer.cc", @@ -2498,13 +2496,13 @@ tf_cuda_library( hdrs = FRAMEWORK_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), linkopts = select({ - "//tensorflow:freebsd": [], + "//tensorflow:freebsd": ["-lm"], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], - "//conditions:default": ["-ldl"], - }) + [ - "-lm", - ], + "//conditions:default": [ + "-ldl", + "-lm", + ], + }), deps = [ ":lib", ":lib_internal", @@ -2519,12 +2517,7 @@ tf_cuda_library( ] + if_static( extra_deps = ["@protobuf_archive//:protobuf"], otherwise = ["@protobuf_archive//:protobuf_headers"], - ) + if_mkl( - [ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], - ), + ) + mkl_deps(), alwayslink = 1, ) @@ -2806,12 +2799,7 @@ tf_cuda_library( ":protos_all_cc", "//third_party/eigen3", "//tensorflow/core/grappler:grappler_item", - ] + if_mkl( - [ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], - ), + ] + mkl_deps(), alwayslink = 1, ) @@ -2851,12 +2839,7 @@ tf_cuda_library( "//tensorflow/core/grappler/optimizers:meta_optimizer", "//third_party/eigen3", "//tensorflow/core/kernels:required", - ] + if_mkl( - [ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], - ) + tf_additional_core_deps() + if_static([":core_cpu_impl"]), + ] + mkl_deps() + tf_additional_core_deps() + if_static([":core_cpu_impl"]), alwayslink = 1, ) @@ -3149,7 +3132,10 @@ cc_library( testonly = 1, srcs = ["platform/test_main.cc"], copts = tf_copts(), - linkopts = ["-lm"], + linkopts = select({ + "//tensorflow:windows": [], + "//conditions:default": ["-lm"], + }), visibility = ["//tensorflow:internal"], deps = [ ":lib", @@ -3860,11 +3846,7 @@ tf_cuda_only_cc_test( ":test", ":test_main", "//third_party/eigen3", - ] + if_mkl( - [ - "//third_party/mkl:intel_binary_blob", - ], - ), + ] + mkl_deps(), ) tf_cc_test_gpu( diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index ae03a61ae66..51812caeb29 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -59,8 +59,8 @@ void GetGoldenApiDefs(Env* env, const string& api_files_dir, file_contents = PBTxtFromMultiline(file_contents); ApiDefs api_defs; - CHECK(tensorflow::protobuf::TextFormat::ParseFromString(file_contents, - &api_defs)) + QCHECK(tensorflow::protobuf::TextFormat::ParseFromString(file_contents, + &api_defs)) << "Failed to load " << file_path; CHECK_EQ(api_defs.op_size(), 1); (*name_to_api_def)[api_defs.op(0).graph_op_name()] = api_defs.op(0); diff --git a/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt index a0e42dd02c5..9f3f9b276b4 100644 --- a/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt @@ -123,5 +123,7 @@ Batched indexing into a 3-tensor: [['a1', 'b1'], ['c1', 'd1']]] output = [['b0', 'b1'], ['d0', 'c1']] ``` + +See also `tf.gather` and `tf.batch_gather`. END } diff --git a/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt index 162ef2b033e..c6104da4a64 100644 --- a/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt @@ -54,5 +54,7 @@ params.shape[axis + 1:]` where: Note that on CPU, if an out of bound index is found, an error is returned. On GPU, if an out of bound index is found, a 0 is stored in the corresponding output value. 
+ +See also `tf.batch_gather` and `tf.gather_nd`. END } diff --git a/tensorflow/core/api_def/base_api/api_def_HostConst.pbtxt b/tensorflow/core/api_def/base_api/api_def_HostConst.pbtxt new file mode 100644 index 00000000000..9d04a01f6fc --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_HostConst.pbtxt @@ -0,0 +1,11 @@ +op { + graph_op_name: "HostConst" + attr { + name: "value" + description: <